From 7005c0c7c1fb420e546c01239a896df13b3da434 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 30 Aug 2018 15:14:19 +0200 Subject: [PATCH 001/163] ENH New implementation of K-means using chunks, speed improvement and change of parallelism level. * Performed on lloyd and elkan algorithms * Use of openmp in cython to get parallelism at chunks level. * Use of scipy shipped cython blas to optimize the computation of pairwise distances. * Deprecate precompute_distances. Distances are now always precomputed chunk by chunk -> low memory usage. * Fix bug: center_shift wrongly computed in elkan * Fix bug: convergence condition too strict in elkan * Fix bug: csr_row_norms returns only np.float64 --- sklearn/cluster/_k_means.pyx | 433 ++++++++-------------- sklearn/cluster/_k_means_elkan.pyx | 473 ++++++++++++++---------- sklearn/cluster/_k_means_lloyd.pyx | 455 +++++++++++++++++++++++ sklearn/cluster/k_means_.py | 348 ++++++++--------- sklearn/cluster/setup.py | 20 +- sklearn/cluster/tests/test_k_means.py | 32 +- sklearn/utils/sparsefuncs_fast.pyx | 21 +- sklearn/utils/tests/test_sparsefuncs.py | 16 +- 8 files changed, 1093 insertions(+), 705 deletions(-) create mode 100644 sklearn/cluster/_k_means_lloyd.pyx diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 66fd620a90cdb..382efa6969666 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -1,4 +1,4 @@ -# cython: profile=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True # Profiling is enabled by default as the overhead does not seem to be # measurable on this specific use case. 
@@ -8,167 +8,182 @@ # # License: BSD 3 clause -from libc.math cimport sqrt import numpy as np -import scipy.sparse as sp cimport numpy as np cimport cython from cython cimport floating -from sklearn.utils.sparsefuncs_fast import assign_rows_csr + +np.import_array() + ctypedef np.float64_t DOUBLE ctypedef np.int32_t INT -cdef extern from "cblas.h": - double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY) - float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY) -np.import_array() +cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[::1] sample_weight, + floating[:, ::1] centers, + int[::1] labels): + """Compute inertia for dense input data + + Sum of squared distance between each sample and it's assigned center. + """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int i, j, k + floating tmp, sample_inertia + + floating inertia = 0.0 + for i in xrange(n_samples): + j = labels[i] + sample_inertia = 0.0 + for k in xrange(n_features): + tmp = X[i, k] - centers[j, k] + sample_inertia += tmp * tmp + inertia += sample_inertia * sample_weight[i] -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) -cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, - np.ndarray[floating, ndim=1] sample_weight, - np.ndarray[floating, ndim=1] x_squared_norms, - np.ndarray[floating, ndim=2] centers, - np.ndarray[INT, ndim=1] labels, - np.ndarray[floating, ndim=1] distances): - """Compute label assignment and inertia for a dense array + return inertia - Return the inertia (sum of squared distances to the centers). + +cpdef floating _inertia_sparse(X, + floating[::1] sample_weight, + floating[:, ::1] centers, + int[::1] labels): + """Compute inertia for sparse input data + + Sum of squared distance between each sample and it's assigned center. 
""" cdef: - unsigned int n_clusters = centers.shape[0] - unsigned int n_features = centers.shape[1] - unsigned int n_samples = X.shape[0] - unsigned int x_stride - unsigned int center_stride - unsigned int sample_idx, center_idx, feature_idx - unsigned int store_distances = 0 - unsigned int k - np.ndarray[floating, ndim=1] center_squared_norms - # the following variables are always double cause make them floating - # does not save any memory, but makes the code much bigger - DOUBLE inertia = 0.0 - DOUBLE min_dist - DOUBLE dist - - if floating is float: - center_squared_norms = np.zeros(n_clusters, dtype=np.float32) - x_stride = X.strides[1] / sizeof(float) - center_stride = centers.strides[1] / sizeof(float) - dot = sdot - else: - center_squared_norms = np.zeros(n_clusters, dtype=np.float64) - x_stride = X.strides[1] / sizeof(DOUBLE) - center_stride = centers.strides[1] / sizeof(DOUBLE) - dot = ddot - - if n_samples == distances.shape[0]: - store_distances = 1 - - for center_idx in range(n_clusters): - center_squared_norms[center_idx] = dot( - n_features, ¢ers[center_idx, 0], center_stride, - ¢ers[center_idx, 0], center_stride) - - for sample_idx in range(n_samples): - min_dist = -1 - for center_idx in range(n_clusters): - dist = 0.0 - # hardcoded: minimize euclidean distance to cluster center: - # ||a - b||^2 = ||a||^2 + ||b||^2 -2 - dist += dot(n_features, &X[sample_idx, 0], x_stride, - ¢ers[center_idx, 0], center_stride) - dist *= -2 - dist += center_squared_norms[center_idx] - dist += x_squared_norms[sample_idx] - dist *= sample_weight[sample_idx] - if min_dist == -1 or dist < min_dist: - min_dist = dist - labels[sample_idx] = center_idx - - if store_distances: - distances[sample_idx] = min_dist - inertia += min_dist + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + int n_samples = X_indptr.shape[0] - 1 + int n_features = centers.shape[1] + int i, j, k + int row_ptr, nz_len, nz_ptr + floating tmp, sample_inertia 
+ + floating inertia = 0.0 + + for i in xrange(n_samples): + j = labels[i] + sample_inertia = 0.0 + row_ptr = X_indptr[i] + nz_len = X_indptr[i + 1] - X_indptr[i] + nz_ptr = 0 + for k in xrange(n_features): + if nz_ptr < nz_len and k == X_indices[row_ptr + nz_ptr]: + tmp = X_data[row_ptr + nz_ptr] - centers[j, k] + nz_ptr += 1 + else: + tmp = - centers[j, k] + sample_inertia += tmp * tmp + inertia += sample_inertia * sample_weight[i] return inertia -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) -cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight, - np.ndarray[DOUBLE, ndim=1] x_squared_norms, - np.ndarray[floating, ndim=2] centers, - np.ndarray[INT, ndim=1] labels, - np.ndarray[floating, ndim=1] distances): - """Compute label assignment and inertia for a CSR input +cpdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[::1] sample_weight, + floating[:, ::1] centers, + floating[::1] weight_in_clusters, + int[::1] labels): + """Relocate centers which have no sample assigned to them""" + cdef: + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int n_empty = empty_clusters.shape[0] + + if n_empty == 0: + return - Return the inertia (sum of squared distances to the centers). 
- """ cdef: - np.ndarray[floating, ndim=1] X_data = X.data - np.ndarray[INT, ndim=1] X_indices = X.indices - np.ndarray[INT, ndim=1] X_indptr = X.indptr - unsigned int n_clusters = centers.shape[0] - unsigned int n_features = centers.shape[1] - unsigned int n_samples = X.shape[0] - unsigned int store_distances = 0 - unsigned int sample_idx, center_idx, feature_idx - unsigned int k - np.ndarray[floating, ndim=1] center_squared_norms - # the following variables are always double cause make them floating - # does not save any memory, but makes the code much bigger - DOUBLE inertia = 0.0 - DOUBLE min_dist - DOUBLE dist - - if floating is float: - center_squared_norms = np.zeros(n_clusters, dtype=np.float32) - dot = sdot - else: - center_squared_norms = np.zeros(n_clusters, dtype=np.float64) - dot = ddot - - if n_samples == distances.shape[0]: - store_distances = 1 + int n_features = X.shape[1] - for center_idx in range(n_clusters): - center_squared_norms[center_idx] = dot( - n_features, ¢ers[center_idx, 0], 1, - ¢ers[center_idx, 0], 1) - - for sample_idx in range(n_samples): - min_dist = -1 - for center_idx in range(n_clusters): - dist = 0.0 - # hardcoded: minimize euclidean distance to cluster center: - # ||a - b||^2 = ||a||^2 + ||b||^2 -2 - for k in range(X_indptr[sample_idx], X_indptr[sample_idx + 1]): - dist += centers[center_idx, X_indices[k]] * X_data[k] - dist *= -2 - dist += center_squared_norms[center_idx] - dist += x_squared_norms[sample_idx] - dist *= sample_weight[sample_idx] - if min_dist == -1 or dist < min_dist: - min_dist = dist - labels[sample_idx] = center_idx - if store_distances: - distances[sample_idx] = dist - inertia += min_dist + floating[::1] distances = ((np.asarray(X) - np.asarray(centers)[labels])**2).sum(axis=1) - return inertia + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) + + int new_cluster_id, old_cluster_id, far_idx, idx, k + floating weight + + if n_empty > 0: + for idx in 
xrange(n_empty): + + new_cluster_id = empty_clusters[idx] + + far_idx = far_from_centers[idx] + weight = sample_weight[far_idx] + + old_cluster_id = labels[far_idx] + + for k in xrange(n_features): + centers[new_cluster_id, k] = X[far_idx, k] * weight + centers[old_cluster_id, k] -= X[far_idx, k] * weight + + weight_in_clusters[new_cluster_id] = weight + weight_in_clusters[old_cluster_id] -= weight + + +cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[:, ::1] centers, + floating[::1] weight_in_clusters, + int[::1] labels): + """Relocate centers which have no sample assigned to them""" + cdef: + int[::1] empty_clusters = \ + np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int n_empty = empty_clusters.shape[0] + + if n_empty == 0: + return + + cdef: + int n_samples = X_indptr.shape[0] - 1 + floating x + int i, j, k + + floating[::1] distances = np.zeros(n_samples, dtype=X_data.base.dtype) + + for i in xrange(n_samples): + j = labels[i] + for k in xrange(X_indptr[i], X_indptr[i + 1]): + x = (X_data[k] - centers[j, X_indices[k]]) + distances[i] += x * x + + cdef: + int[::1] far_from_centers = \ + np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) + + int new_cluster_id, old_cluster_id, far_idx, idx + floating weight + + if n_empty > 0: + for idx in xrange(n_empty): + + new_cluster_id = empty_clusters[idx] + + far_idx = far_from_centers[idx] + weight = sample_weight[far_idx] + + old_cluster_id = labels[far_idx] + + for k in xrange(X_indptr[far_idx], X_indptr[far_idx + 1]): + centers[new_cluster_id, X_indices[k]] += X_data[k] * weight + centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight + + weight_in_clusters[new_cluster_id] = weight + weight_in_clusters[old_cluster_id] -= weight -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight, - 
np.ndarray[DOUBLE, ndim=1] x_squared_norms, + np.ndarray[floating, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[floating, ndim=1] weight_sums, np.ndarray[INT, ndim=1] nearest_center, @@ -266,149 +281,3 @@ def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight, - centers[center_idx, feature_idx]) ** 2 return squared_diff - - -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) -def _centers_dense(np.ndarray[floating, ndim=2] X, - np.ndarray[floating, ndim=1] sample_weight, - np.ndarray[INT, ndim=1] labels, int n_clusters, - np.ndarray[floating, ndim=1] distances): - """M step of the K-means EM algorithm - - Computation of cluster centers / means. - - Parameters - ---------- - X : array-like, shape (n_samples, n_features) - - sample_weight : array-like, shape (n_samples,) - The weights for each observation in X. - - labels : array of integers, shape (n_samples) - Current label assignment - - n_clusters : int - Number of desired clusters - - distances : array-like, shape (n_samples) - Distance to closest cluster for each sample. - - Returns - ------- - centers : array, shape (n_clusters, n_features) - The resulting centers - """ - ## TODO: add support for CSR input - cdef int n_samples, n_features - n_samples = X.shape[0] - n_features = X.shape[1] - cdef int i, j, c - cdef np.ndarray[floating, ndim=2] centers - cdef np.ndarray[floating, ndim=1] weight_in_cluster - - dtype = np.float32 if floating is float else np.float64 - centers = np.zeros((n_clusters, n_features), dtype=dtype) - weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) - - for i in range(n_samples): - c = labels[i] - weight_in_cluster[c] += sample_weight[i] - empty_clusters = np.where(weight_in_cluster == 0)[0] - # maybe also relocate small clusters? 
- - if len(empty_clusters): - # find points to reassign empty clusters to - far_from_centers = distances.argsort()[::-1] - - for i, cluster_id in enumerate(empty_clusters): - # XXX two relocated clusters could be close to each other - far_index = far_from_centers[i] - new_center = X[far_index] - centers[cluster_id] = new_center - weight_in_cluster[cluster_id] = sample_weight[far_index] - - for i in range(n_samples): - for j in range(n_features): - centers[labels[i], j] += X[i, j] * sample_weight[i] - - centers /= weight_in_cluster[:, np.newaxis] - - return centers - - -@cython.boundscheck(False) -@cython.wraparound(False) -@cython.cdivision(True) -def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weight, - np.ndarray[INT, ndim=1] labels, n_clusters, - np.ndarray[floating, ndim=1] distances): - """M step of the K-means EM algorithm - - Computation of cluster centers / means. - - Parameters - ---------- - X : scipy.sparse.csr_matrix, shape (n_samples, n_features) - - sample_weight : array-like, shape (n_samples,) - The weights for each observation in X. - - labels : array of integers, shape (n_samples) - Current label assignment - - n_clusters : int - Number of desired clusters - - distances : array-like, shape (n_samples) - Distance to closest cluster for each sample. 
- - Returns - ------- - centers : array, shape (n_clusters, n_features) - The resulting centers - """ - cdef int n_samples, n_features - n_samples = X.shape[0] - n_features = X.shape[1] - cdef int curr_label - - cdef np.ndarray[floating, ndim=1] data = X.data - cdef np.ndarray[int, ndim=1] indices = X.indices - cdef np.ndarray[int, ndim=1] indptr = X.indptr - - cdef np.ndarray[floating, ndim=2, mode="c"] centers - cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers - cdef np.ndarray[floating, ndim=1] weight_in_cluster - dtype = np.float32 if floating is float else np.float64 - centers = np.zeros((n_clusters, n_features), dtype=dtype) - weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) - for i in range(n_samples): - c = labels[i] - weight_in_cluster[c] += sample_weight[i] - cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \ - np.where(weight_in_cluster == 0)[0] - cdef int n_empty_clusters = empty_clusters.shape[0] - - # maybe also relocate small clusters? - - if n_empty_clusters > 0: - # find points to reassign empty clusters to - far_from_centers = distances.argsort()[::-1][:n_empty_clusters] - - # XXX two relocated clusters could be close to each other - assign_rows_csr(X, far_from_centers, empty_clusters, centers) - - for i in range(n_empty_clusters): - weight_in_cluster[empty_clusters[i]] = 1 - - for i in range(labels.shape[0]): - curr_label = labels[i] - for ind in range(indptr[i], indptr[i + 1]): - j = indices[ind] - centers[curr_label, j] += data[ind] * sample_weight[i] - - centers /= weight_in_cluster[:, np.newaxis] - - return centers diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index a0734a624f14e..ce41c2534e227 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -1,7 +1,4 @@ -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: profile=True +# cython: profile=True, boundscheck=False, wraparound=False, 
cdivision=True # # Author: Andreas Mueller # @@ -10,30 +7,48 @@ import numpy as np cimport numpy as np cimport cython +cimport openmp from cython cimport floating - +from cython.parallel import prange, parallel from libc.math cimport sqrt +from libc.stdlib cimport malloc, free +from libc.string cimport memset, memcpy from ..metrics import euclidean_distances -from ._k_means import _centers_dense +from ._k_means import _relocate_empty_clusters_dense + + +np.import_array() cdef floating euclidean_dist(floating* a, floating* b, int n_features) nogil: - cdef floating result, tmp - result = 0 - cdef int i - for i in range(n_features): - tmp = (a[i] - b[i]) - result += tmp * tmp + """Euclidean distance between a and b, optimized for vectorization""" + cdef: + int i + int n = n_features // 4 + int rem = n_features % 4 + floating result = 0 + + for i in range(n): + result += ((a[0] - b[0]) * (a[0] - b[0]) + +(a[1] - b[1]) * (a[1] - b[1]) + +(a[2] - b[2]) * (a[2] - b[2]) + +(a[3] - b[3]) * (a[3] - b[3])) + a += 4; b += 4 + + for i in range(rem): + result += (a[i] - b[i]) * (a[i] - b[i]) + return sqrt(result) -cdef update_labels_distances_inplace( - floating* X, floating* centers, floating[:, :] center_half_distances, - int[:] labels, floating[:, :] lower_bounds, floating[:] upper_bounds, - int n_samples, int n_features, int n_clusters): - """ - Calculate upper and lower bounds for each sample. +cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, + floating[:, ::1] centers, + floating[:, ::1] center_half_distances, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds): + """Initialize upper and lower bounds for each sample. Given X, centers and the pairwise distances divided by 2.0 between the centers this calculates the upper bounds and lower bounds for each sample. @@ -69,193 +84,275 @@ cdef update_labels_distances_inplace( upper_bounds : nd-array, shape(n_samples,) The distance of each sample from its closest cluster center. 
This is modified in place by the function. + """ + cdef: + int n_samples = X.shape[0] + int n_clusters = centers.shape[0] + int n_features = X.shape[1] - n_samples : int - The number of samples. + floating min_dist, dist + int best_cluster, i, j - n_features : int - The number of features. + center_half_distances = euclidean_distances(np.asarray(centers)) / 2 - n_clusters : int - The number of clusters. - """ - # assigns closest center to X - # uses triangle inequality - cdef floating* x - cdef floating* c - cdef floating d_c, dist - cdef int c_x, j, sample - for sample in range(n_samples): - # assign first cluster center - c_x = 0 - x = X + sample * n_features - d_c = euclidean_dist(x, centers, n_features) - lower_bounds[sample, 0] = d_c + for i in range(n_samples): + best_cluster = 0 + min_dist = euclidean_dist(&X[i, 0], &centers[0, 0], n_features) + lower_bounds[i, 0] = min_dist for j in range(1, n_clusters): - if d_c > center_half_distances[c_x, j]: - c = centers + j * n_features - dist = euclidean_dist(x, c, n_features) - lower_bounds[sample, j] = dist - if dist < d_c: - d_c = dist - c_x = j - labels[sample] = c_x - upper_bounds[sample] = d_c - - -def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, - np.ndarray[floating, ndim=1, mode='c'] sample_weight, - int n_clusters, - np.ndarray[floating, ndim=2, mode='c'] init, - float tol=1e-4, int max_iter=30, verbose=False): - """Run Elkan's k-means. 
+ if min_dist > center_half_distances[best_cluster, j]: + dist = euclidean_dist(&X[i, 0], ¢ers[j, 0], n_features) + lower_bounds[i, j] = dist + if dist < min_dist: + min_dist = dist + best_cluster = j + labels[i] = best_cluster + upper_bounds[i] = min_dist + + +cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] weight_in_clusters, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + int[::1] labels, + int n_jobs = -1, + bint update_centers = True): + """Single interation of K-means elkan algorithm + + Update labels and centers (inplace), for one iteration, distributed + over data chunks. Parameters ---------- - X_ : nd-array, shape (n_samples, n_features) + X : {float32, float64} array-like, shape (n_samples, n_features) + The observations to cluster. - sample_weight : nd-array, shape (n_samples,) + sample_weight : {float32, float64} array-like, shape (n_samples,) The weights for each observation in X. - n_clusters : int - Number of clusters to find. - - init : nd-array, shape (n_clusters, n_features) - Initial position of centers. - - tol : float, default=1e-4 - The relative increment in cluster means before declaring convergence. + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) + Centers before previous iteration, placeholder for the centers after + previous iteration. + + centers_new : {float32, float64} array-like, shape (n_clusters, n_features) + Centers after previous iteration, placeholder for the new centers + computed during this iteration. + + weight_in_clusters : {float32, float64} array-like, shape (n_clusters,) + Placeholder for the sums of the weights of every observation assigned + to each center. 
+ + center_half_distances : {float32, float64} array-like, \ +shape (n_clusters, n_clusters) + Half pairwise distances between centers. + + distance_next_center : {float32, float64} array-like, shape (n_clusters,) + Distance between each center and its closest center. + + upper_bounds : {float32, float64} array-like, shape (n_samples,) + Upper bound for the distance between each sample and its center, + updated inplace. + + lower_bounds : {float32, float64} array-like, shape (n_samples, n_clusters) + Lower bound for the distance between each sample and each center, + updated inplace. + + labels : int array-like, shape (n_samples,) + labels assignment. + + n_jobs : int + The number of threads to be used by openmp. If -1, openmp will use as + many as possible. + + update_centers : bool + - If True, the labels and the new centers will be computed, i.e. runs + the E-step and the M-step of the algorithm. + - If False, only the labels will be computed, i.e. runs the E-step of + the algorithm. + """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int n_clusters = centers_new.shape[0] + + # hard-coded number of samples per chunk. Appeared to be close to + # optimal in all situations. 
+ int n_samples_chunk = 256 if n_samples > 256 else n_samples + int n_chunks = n_samples // n_samples_chunk + int n_samples_r = n_samples % n_samples_chunk + int chunk_idx, n_samples_chunk_eff + int num_threads + + int i, j, k + int label + floating alpha, tmp, x + + floating *centers_new_chunk + floating *weight_in_clusters_chunk + + floating[::1] center_shift = np.zeros(n_clusters, dtype=X.dtype) + + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk + + # re-initialize all arrays at each iteration + if update_centers: + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) + memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) + + # compute pairwise distances between centers and get next closest center + distance_next_center = np.partition(np.asarray(center_half_distances), kth=1, axis=0)[1] + + # set number of threads to be used by openmp + num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + with nogil, parallel(num_threads=num_threads): + # thread local buffers + centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) + weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) + # initialize local buffers + memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) + memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) + + for chunk_idx in prange(n_chunks): + if n_samples_r > 0 and chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + _update_chunk( + &X[chunk_idx * n_samples_chunk, 0], + &sample_weight[chunk_idx * n_samples_chunk], + ¢ers_old[0, 0], + centers_new_chunk, + ¢er_half_distances[0, 0], + &distance_next_center[0], + weight_in_clusters_chunk, + &labels[chunk_idx * n_samples_chunk], + &upper_bounds[chunk_idx * n_samples_chunk], + &lower_bounds[chunk_idx * 
n_samples_chunk, 0], + n_samples_chunk_eff, + n_clusters, + n_features, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. + if update_centers: + with gil: + for j in xrange(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in xrange(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] + + free(weight_in_clusters_chunk) + free(centers_new_chunk) + + if update_centers: + _relocate_empty_clusters_dense(X, sample_weight, centers_new, + weight_in_clusters, labels) + + # average new centers wrt sample weights + for j in xrange(n_clusters): + if weight_in_clusters[j] > 0: + alpha = 1.0 / weight_in_clusters[j] + for k in xrange(n_features): + centers_new[j, k] *= alpha + + # compute shift distance between old and new centers + for j in range(n_clusters): + tmp = 0 + for k in range(n_features): + x = centers_new[j, k] - centers_old[j, k] + tmp += x * x + center_shift[j] = sqrt(tmp) + + # update lower and upper bounds accordingly + for i in range(n_samples): + upper_bounds[i] += center_shift[labels[i]] + + for j in range(n_clusters): + lower_bounds[i, j] -= center_shift[j] + if lower_bounds[i, j] < 0: + lower_bounds[i, j] = 0 + + center_half_distances = euclidean_distances(np.asarray(centers_old)) / 2 + + +cdef void _update_chunk(floating *X, + floating *sample_weight, + floating *centers_old, + floating *centers_new, + floating *center_half_distances, + floating *distance_next_center, + floating *weight_in_clusters, + int *labels, + floating *upper_bounds, + floating *lower_bounds, + int n_samples, + int n_clusters, + int n_features, + bint update_centers) nogil: + """K-means step for one data chunk using elkan algorithm + + Compute the partial contribution of a single data chunk to the labels and + centers. + """ + cdef: + floating upper_bound, distance + int i, j, k, label - max_iter : int, default=30 - Maximum number of iterations of the k-means algorithm. 
+ for i in range(n_samples): + upper_bound = upper_bounds[i] + bounds_tight = 0 + label = labels[i] - verbose : bool, default=False - Whether to be verbose. + # Next center is not far away from the currently assigned center. + # Sample might need to be assigned to another center. + if not distance_next_center[label] >= upper_bound: - """ - if floating is float: - dtype = np.float32 - else: - dtype = np.float64 - - # initialize - cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init - cdef floating* centers_p = centers_.data - cdef floating* X_p = X_.data - cdef floating* x_p - cdef Py_ssize_t n_samples = X_.shape[0] - cdef Py_ssize_t n_features = X_.shape[1] - cdef int point_index, center_index, label - cdef floating upper_bound, distance - cdef floating[:, :] center_half_distances = euclidean_distances(centers_) / 2. - cdef floating[:, :] lower_bounds = np.zeros((n_samples, n_clusters), dtype=dtype) - cdef floating[:] distance_next_center - labels_ = np.empty(n_samples, dtype=np.int32) - cdef int[:] labels = labels_ - upper_bounds_ = np.empty(n_samples, dtype=dtype) - cdef floating[:] upper_bounds = upper_bounds_ - - # Get the initial set of upper bounds and lower bounds for each sample. 
- update_labels_distances_inplace(X_p, centers_p, center_half_distances, - labels, lower_bounds, upper_bounds, - n_samples, n_features, n_clusters) - cdef np.uint8_t[:] bounds_tight = np.ones(n_samples, dtype=np.uint8) - cdef np.uint8_t[:] points_to_update = np.zeros(n_samples, dtype=np.uint8) - cdef np.ndarray[floating, ndim=2, mode='c'] new_centers - - if max_iter <= 0: - raise ValueError('Number of iterations should be a positive number' - ', got %d instead' % max_iter) - - col_indices = np.arange(center_half_distances.shape[0], dtype=np.int) - for iteration in range(max_iter): - if verbose: - print("start iteration") - - cd = np.asarray(center_half_distances) - distance_next_center = np.partition(cd, kth=1, axis=0)[1] - - if verbose: - print("done sorting") - - for point_index in range(n_samples): - upper_bound = upper_bounds[point_index] - label = labels[point_index] - - # This means that the next likely center is far away from the - # currently assigned center and the sample is unlikely to be - # reassigned. - if distance_next_center[label] >= upper_bound: - continue - x_p = X_p + point_index * n_features - - # TODO: get pointer to lower_bounds[point_index, center_index] - for center_index in range(n_clusters): + for j in range(n_clusters): # If this holds, then center_index is a good candidate for the # sample to be relabelled, and we need to confirm this by # recomputing the upper and lower bounds. - if (center_index != label - and (upper_bound > lower_bounds[point_index, center_index]) - and (upper_bound > center_half_distances[center_index, label])): - - # Recompute the upper bound by calculating the actual distance - # between the sample and label. - if not bounds_tight[point_index]: - upper_bound = euclidean_dist(x_p, centers_p + label * n_features, n_features) - lower_bounds[point_index, label] = upper_bound - bounds_tight[point_index] = 1 - - # If the condition still holds, then compute the actual distance between - # the sample and center_index. 
If this is still lesser than the previous - # distance, reassign labels. - if (upper_bound > lower_bounds[point_index, center_index] - or (upper_bound > center_half_distances[label, center_index])): - distance = euclidean_dist(x_p, centers_p + center_index * n_features, n_features) - lower_bounds[point_index, center_index] = distance + if (j != label + and (upper_bound > lower_bounds[i * n_clusters + j]) + and (upper_bound > center_half_distances[label * n_clusters + j])): + + # Recompute upper bound by calculating the actual distance + # between the sample and it's current assigned center. + if not bounds_tight: + upper_bound = euclidean_dist(X + i * n_features, + centers_old + label * n_features, + n_features) + lower_bounds[i * n_clusters + label] = upper_bound + bounds_tight = 1 + + # If the condition still holds, then compute the actual + # distance between the sample and center. If this is less + #than the previous distance, reassign label. + if (upper_bound > lower_bounds[i * n_clusters + j] + or (upper_bound > center_half_distances[label * n_clusters + j])): + + distance = euclidean_dist(X + i * n_features, + centers_old + j * n_features, + n_features) + lower_bounds[i * n_clusters + j] = distance if distance < upper_bound: - label = center_index + label = j upper_bound = distance - labels[point_index] = label - upper_bounds[point_index] = upper_bound - - if verbose: - print("end inner loop") - - # compute new centers - new_centers = _centers_dense(X_, sample_weight, labels_, - n_clusters, upper_bounds_) - bounds_tight[:] = 0 - - # compute distance each center moved - center_shift = np.sqrt(np.sum((centers_ - new_centers) ** 2, axis=1)) - - # update bounds accordingly - lower_bounds = np.maximum(lower_bounds - center_shift, 0) - upper_bounds = upper_bounds + center_shift[labels_] - - # reassign centers - centers_ = new_centers - centers_p = new_centers.data - - # update between-center distances - center_half_distances = euclidean_distances(centers_) / 2. 
- if verbose: - print('Iteration %i, inertia %s' - % (iteration, np.sum((X_ - centers_[labels]) ** 2 * - sample_weight[:,np.newaxis]))) - center_shift_total = np.sum(center_shift) - if center_shift_total ** 2 < tol: - if verbose: - print("center shift %e within tolerance %e" - % (center_shift_total, tol)) - break - - # We need this to make sure that the labels give the same output as - # predict(X) - if center_shift_total > 0: - update_labels_distances_inplace(X_p, centers_p, center_half_distances, - labels, lower_bounds, upper_bounds, - n_samples, n_features, n_clusters) - return centers_, labels_, iteration + 1 + labels[i] = label + upper_bounds[i] = upper_bound + + if update_centers: + weight_in_clusters[label] += sample_weight[i] + for k in range(n_features): + centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i] \ No newline at end of file diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx new file mode 100644 index 0000000000000..c1c1f980be35b --- /dev/null +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -0,0 +1,455 @@ +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# +# Licence: BSD 3 clause + +import numpy as np +cimport numpy as np +cimport cython +cimport openmp +from cython cimport floating +from cython.parallel import prange, parallel +from scipy.linalg.cython_blas cimport sgemm, dgemm +from libc.stdlib cimport malloc, free +from libc.string cimport memset, memcpy + +from ._k_means import (_relocate_empty_clusters_dense, + _relocate_empty_clusters_sparse) + + +np.import_array() + + +cdef: + float MAX_FLT = np.finfo(np.float32).max + double MAX_DBL = np.finfo(np.float64).max + + +cdef void xgemm(char *ta, char *tb, int *m, int *n, int *k, floating *alpha, + floating *A, int *lda, floating *B, int *ldb, floating *beta, + floating *C, int *ldc) nogil: + if floating is float: + sgemm(ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) + else: + dgemm(ta, tb, m, n, k, 
alpha, A, lda, B, ldb, beta, C, ldc) + + +cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, + int n_jobs = -1, + bint update_centers = True): + """Single interation of K-means lloyd algorithm + + Update labels and centers (inplace), for one iteration, distributed + over data chunks. + + Parameters + ---------- + X : {float32, float64} array-like, shape (n_samples, n_features) + The observations to cluster. + + sample_weight : {float32, float64} array-like, shape (n_samples,) + The weights for each observation in X. + + x_squared_norms : {float32, float64} array-like, shape (n_samples,) + Squared L2 norm of X. + + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) + Centers before previous iteration, placeholder for the centers after + previous iteration. + + centers_new : {float32, float64} array-like, shape (n_clusters, n_features) + Centers after previous iteration, placeholder for the new centers + computed during this iteration. + + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) + Squared L2 norm of the centers. + + weight_in_clusters : {float32, float64} array-like, shape (n_clusters,) + Placeholder for the sums of the weights of every observation assigned + to each center. + + labels : int array-like, shape (n_samples,) + labels assignment. + + n_jobs : int + The number of threads to be used by openmp. If -1, openmp will use as + many as possible. + + update_centers : bool + - If True, the labels and the new centers will be computed, i.e. runs + the E-step and the M-step of the algorithm. + - If False, only the labels will be computed, i.e runs the E-step of + the algorithm. 
+ """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int n_clusters = centers_new.shape[0] + + # hard-coded number of samples per chunk. Appeared to be close to + # optimal in all situations. + int n_samples_chunk = 256 if n_samples > 256 else n_samples + int n_chunks = n_samples // n_samples_chunk + int n_samples_r = n_samples % n_samples_chunk + int chunk_idx, n_samples_chunk_eff + int num_threads + + int j, k + floating alpha + + floating *centers_new_chunk + floating *weight_in_clusters_chunk + floating *pairwise_distances_chunk + + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk + + # re-initialize all arrays at each iteration + memset(¢ers_squared_norms[0], 0, n_clusters * sizeof(floating)) + for j in xrange(n_clusters): + for k in xrange(n_features): + centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + + if update_centers: + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], + n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, + n_clusters * n_features * sizeof(floating)) + memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) + + # set number of threads to be used by openmp + num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): + centers_new_chunk = \ + malloc(n_clusters * n_features * sizeof(floating)) + + weight_in_clusters_chunk = \ + malloc(n_clusters * sizeof(floating)) + + pairwise_distances_chunk = \ + malloc(n_samples_chunk * n_clusters * sizeof(floating)) + + # initialize local buffers + memset(centers_new_chunk, 0, + n_clusters * n_features * sizeof(floating)) + memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) + + for chunk_idx in prange(n_chunks): + if n_samples_r > 0 and chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + _update_chunk_dense( + &X[chunk_idx * n_samples_chunk, 0], + 
&sample_weight[chunk_idx * n_samples_chunk], + &x_squared_norms[chunk_idx * n_samples_chunk], + ¢ers_old[0, 0], + centers_new_chunk, + ¢ers_squared_norms[0], + weight_in_clusters_chunk, + pairwise_distances_chunk, + &labels[chunk_idx * n_samples_chunk], + n_samples_chunk_eff, + n_clusters, + n_features, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. + if update_centers: + with gil: + for j in xrange(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in xrange(n_features): + centers_new[j, k] += \ + centers_new_chunk[j * n_features + k] + + free(weight_in_clusters_chunk) + free(centers_new_chunk) + free(pairwise_distances_chunk) + + if update_centers: + _relocate_empty_clusters_dense(X, sample_weight, centers_new, + weight_in_clusters, labels) + + # average new centers wrt sample weights + for j in xrange(n_clusters): + if weight_in_clusters[j] > 0: + alpha = 1.0 / weight_in_clusters[j] + for k in xrange(n_features): + centers_new[j, k] *= alpha + + +cdef void _update_chunk_dense(floating *X, + floating *sample_weight, + floating *x_squared_norms, + floating *centers_old, + floating *centers_new, + floating *centers_squared_norms, + floating *weight_in_clusters, + floating *pairwise_distances, + int *labels, + int n_samples, + int n_clusters, + int n_features, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + + Compute the partial contribution of a single data chunk to the labels and + centers. + """ + cdef: + floating sq_dist, min_sq_dist + int i, j, k, best_cluster + + # parameters for the BLAS gemm + floating alpha = -2.0 + floating beta = 1.0 + char *trans_data = 'n' + char *trans_centers = 't' + + # Instead of computing the full pairwise squared distances matrix, + # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store + # the - 2 X.C^T + ||C||² term since the argmin for a given sample only + # depends on the centers. 
+ for i in xrange(n_samples): + for j in xrange(n_clusters): + pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] + + xgemm(trans_centers, trans_data, &n_clusters, &n_samples, &n_features, + &alpha, centers_old, &n_features, X, &n_features, + &beta, pairwise_distances, &n_clusters) + + for i in xrange(n_samples): + min_sq_dist = pairwise_distances[i * n_clusters] + best_cluster = 0 + for j in xrange(n_clusters): + sq_dist = pairwise_distances[i * n_clusters + j] + if sq_dist < min_sq_dist: + min_sq_dist = sq_dist + best_cluster = j + + labels[i] = best_cluster + + if update_centers: + weight_in_clusters[best_cluster] += sample_weight[i] + for k in xrange(n_features): + centers_new[best_cluster * n_features + k] += \ + X[i * n_features + k] * sample_weight[i] + + +cpdef void _lloyd_iter_chunked_sparse(X, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, + int n_jobs = -1, + bint update_centers = True): + """Single interation of K-means lloyd algorithm + + Update labels and centers (inplace), for one iteration, distributed + over data chunks. + + Parameters + ---------- + X : {float32, float64} CSR matrix, shape (n_samples, n_features) + The observations to cluster. + + sample_weight : {float32, float64} array-like, shape (n_samples,) + The weights for each observation in X. + + x_squared_norms : {float32, float64} array-like, shape (n_samples,) + Squared L2 norm of X. + + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) + Centers before previous iteration, placeholder for the centers after + previous iteration. + + centers_new : {float32, float64} array-like, shape (n_clusters, n_features) + Centers after previous iteration, placeholder for the new centers + computed during this iteration. 
+ + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) + Squared L2 norm of the centers. + + weight_in_clusters : {float32, float64} array-like, shape (n_clusters,) + Placeholder for the sums of the weights of every observation assigned + to each center. + + labels : int array-like, shape (n_samples,) + labels assignment. + + n_jobs : int + The number of threads to be used by openmp. If -1, openmp will use as + many as possible. + + update_centers : bool + - If True, the labels and the new centers will be computed. + - If False, only the labels will be computed. + """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int n_clusters = centers_new.shape[0] + + # Chosed same as for dense. Does not have the same impact since with + # sparse data the pairwise distances matrix is not precomputed. + # However, splitting in chunks is necessary to get parallelism. + int n_samples_chunk = 256 if n_samples > 256 else n_samples + int n_chunks = n_samples // n_samples_chunk + int n_samples_r = n_samples % n_samples_chunk + int chunk_idx, n_samples_chunk_eff + int num_threads + + int j, k + floating alpha + + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + floating *centers_new_chunk + floating *weight_in_clusters_chunk + + # count remainder for total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk + + # re-initialize all arrays at each iteration + memset(¢ers_squared_norms[0], 0, n_clusters * sizeof(floating)) + for j in xrange(n_clusters): + for k in xrange(n_features): + centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + + if update_centers: + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], + n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, + n_clusters * n_features * sizeof(floating)) + memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) + + # set number of threads to be used by openmp + num_threads = n_jobs if n_jobs 
!= -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): + centers_new_chunk = \ + malloc(n_clusters * n_features * sizeof(floating)) + + weight_in_clusters_chunk = \ + malloc(n_clusters * sizeof(floating)) + + # initialize local buffers + memset(centers_new_chunk, 0, + n_clusters * n_features * sizeof(floating)) + memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) + + for chunk_idx in prange(n_chunks): + if n_samples_r > 0 and chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + _update_chunk_sparse( + &X_data[X_indptr[chunk_idx * n_samples_chunk]], + &X_indices[X_indptr[chunk_idx * n_samples_chunk]], + &X_indptr[chunk_idx * n_samples_chunk], + &sample_weight[chunk_idx * n_samples_chunk], + &x_squared_norms[chunk_idx * n_samples_chunk], + ¢ers_old[0, 0], + centers_new_chunk, + ¢ers_squared_norms[0], + weight_in_clusters_chunk, + &labels[chunk_idx * n_samples_chunk], + n_samples_chunk_eff, + n_clusters, + n_features, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. 
+ if update_centers: + with gil: + for j in xrange(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in xrange(n_features): + centers_new[j, k] += \ + centers_new_chunk[j * n_features + k] + + free(weight_in_clusters_chunk) + free(centers_new_chunk) + + if update_centers: + _relocate_empty_clusters_sparse(X_data, X_indices, X_indptr, + sample_weight, centers_new, + weight_in_clusters, labels) + + # average new centers wrt sample weights + for j in xrange(n_clusters): + if weight_in_clusters[j] > 0: + alpha = 1.0 / weight_in_clusters[j] + for k in xrange(n_features): + centers_new[j, k] *= alpha + + +cdef void _update_chunk_sparse(floating *X_data, + int *X_indices, + int *X_indptr, + floating *sample_weight, + floating *x_squared_norms, + floating *centers_old, + floating *centers_new, + floating *centers_squared_norms, + floating *weight_in_cluster, + int *labels, + int n_samples, + int n_clusters, + int n_features, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + + Compute the partial contribution of a single data chunk to the labels and + centers. + """ + cdef: + floating sq_dist, min_sq_dist + int i, j, k, best_cluster + floating max_floating = MAX_FLT if floating is float else MAX_DBL + int s = X_indptr[0] + + # XXX Precompute the pairwise distances matrix is not worth for sparse + # currently. Should be tested when BLAS (sparse x dense) matrix + # multiplication is available. + for i in xrange(n_samples): + min_sq_dist = max_floating + best_cluster = 0 + + for j in xrange(n_clusters): + sq_dist = 0.0 + for k in xrange(X_indptr[i] - s, X_indptr[i + 1] - s): + sq_dist += \ + centers_old[j * n_features + X_indices[k]] * X_data[k] + + # Instead of computing the full squared distance with each cluster, + # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to compute + # the - 2 X.C^T + ||C||² term since the argmin for a given sample + # only depends on the centers C. 
+ sq_dist = centers_squared_norms[j] -2 * sq_dist + if sq_dist < min_sq_dist: + min_sq_dist = sq_dist + best_cluster = j + + labels[i] = best_cluster + + if update_centers: + weight_in_cluster[best_cluster] += sample_weight[i] + for k in xrange(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[best_cluster * n_features + X_indices[k]] += \ + X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 7cc40722e71f4..f6a71671f908e 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -19,7 +19,6 @@ from ..base import BaseEstimator, ClusterMixin, TransformerMixin from ..metrics.pairwise import euclidean_distances -from ..metrics.pairwise import pairwise_distances_argmin_min from ..utils.extmath import row_norms, squared_norm, stable_cumsum from ..utils.sparsefuncs_fast import assign_rows_csr from ..utils.sparsefuncs import mean_variance_axis @@ -29,12 +28,16 @@ from ..utils import check_random_state from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES -from ..utils._joblib import Parallel -from ..utils._joblib import delayed -from ..utils._joblib import effective_n_jobs +from ..utils import effective_n_jobs +from ..externals.six import string_types from ..exceptions import ConvergenceWarning -from . 
import _k_means -from ._k_means_elkan import k_means_elkan +from ._k_means import (_inertia_dense, + _inertia_sparse, + _mini_batch_update_csr) +from ._k_means_lloyd import (_lloyd_iter_chunked_dense, + _lloyd_iter_chunked_sparse) +from ._k_means_elkan import (_init_bounds, + _elkan_iter_chunked_dense) ############################################################################### @@ -183,7 +186,7 @@ def _check_sample_weight(X, sample_weight): def k_means(X, n_clusters, sample_weight=None, init='k-means++', - precompute_distances='auto', n_init=10, max_iter=300, + precompute_distances='not-used', n_init=10, max_iter=300, verbose=False, tol=1e-4, random_state=None, copy_x=True, n_jobs=None, algorithm="auto", return_n_iter=False): """K-means clustering algorithm. @@ -231,6 +234,9 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', True : always precompute distances False : never precompute distances + .. deprecated:: 0.21 + 'precompute_distances' was deprecated in version 0.21 and will be + removed in 0.23. n_init : int, optional, default: 10 Number of time the k-means algorithm will be run with different @@ -295,6 +301,11 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', Returned only if `return_n_iter` is set to True. """ + if precompute_distances != 'not-used': + warnings.warn("'precompute_distances' was deprecated in version" + "0.21 and will be removed in 0.23.", + DeprecationWarning) + if n_init <= 0: raise ValueError("Invalid number of initializations." " n_init=%d must be bigger than zero." % n_init) @@ -315,20 +326,6 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', tol = _tolerance(X, tol) - # If the distances are precomputed every job will create a matrix of shape - # (n_clusters, n_samples). To stop KMeans from eating up memory we only - # activate this if the created matrix is guaranteed to be under 100MB. 12 - # million entries consume a little under 100MB if they are of type double. 
- if precompute_distances == 'auto': - n_samples = X.shape[0] - precompute_distances = (n_clusters * n_samples) < 12e6 - elif isinstance(precompute_distances, bool): - pass - else: - raise ValueError("precompute_distances should be 'auto' or True/False" - ", but a value of %r was passed" % - precompute_distances) - # Validate init array if hasattr(init, '__array__'): init = check_array(init, dtype=X.dtype.type, copy=True) @@ -367,41 +364,22 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', else: raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" " %s" % str(algorithm)) - if effective_n_jobs(n_jobs) == 1: - # For a single thread, less memory is needed if we just store one set - # of the best results (as opposed to one set per run per thread). - for it in range(n_init): - # run a k-means once - labels, inertia, centers, n_iter_ = kmeans_single( - X, sample_weight, n_clusters, max_iter=max_iter, init=init, - verbose=verbose, precompute_distances=precompute_distances, - tol=tol, x_squared_norms=x_squared_norms, - random_state=random_state) - # determine if these results are the best so far - if best_inertia is None or inertia < best_inertia: - best_labels = labels.copy() - best_centers = centers.copy() - best_inertia = inertia - best_n_iter = n_iter_ - else: - # parallelisation of k-means runs - seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) - results = Parallel(n_jobs=n_jobs, verbose=0)( - delayed(kmeans_single)(X, sample_weight, n_clusters, - max_iter=max_iter, init=init, - verbose=verbose, tol=tol, - precompute_distances=precompute_distances, - x_squared_norms=x_squared_norms, - # Change seed to ensure variety - random_state=seed) - for seed in seeds) - # Get results with the lowest inertia - labels, inertia, centers, n_iters = zip(*results) - best = np.argmin(inertia) - best_labels = labels[best] - best_inertia = inertia[best] - best_centers = centers[best] - best_n_iter = n_iters[best] + + n_jobs_ = -1 if 
n_jobs is None else effective_n_jobs(n_jobs) + seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) + + for seed in seeds: + # run a k-means once + labels, inertia, centers, n_iter_ = kmeans_single( + X, sample_weight, n_clusters, max_iter=max_iter, init=init, + verbose=verbose, tol=tol, x_squared_norms=x_squared_norms, + random_state=seed, n_jobs=n_jobs_) + # determine if these results are the best so far + if best_inertia is None or inertia < best_inertia: + best_labels = labels.copy() + best_centers = centers.copy() + best_inertia = inertia + best_n_iter = n_iter_ if not sp.issparse(X): if not copy_x: @@ -423,37 +401,68 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, - random_state=None, tol=1e-4, - precompute_distances=True): + random_state=None, tol=1e-4, n_jobs=None): if sp.issparse(X): raise TypeError("algorithm='elkan' not supported for sparse input X") + random_state = check_random_state(random_state) - if x_squared_norms is None: - x_squared_norms = row_norms(X, squared=True) + sample_weight = _check_sample_weight(X, sample_weight) + # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, x_squared_norms=x_squared_norms) - centers = np.ascontiguousarray(centers) + if verbose: print('Initialization complete') - checked_sample_weight = _check_sample_weight(X, sample_weight) - centers, labels, n_iter = k_means_elkan(X, checked_sample_weight, - n_clusters, centers, tol=tol, - max_iter=max_iter, verbose=verbose) - if sample_weight is None: - inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) - else: - sq_distances = np.sum((X - centers[labels]) ** 2, axis=1, - dtype=np.float64) * checked_sample_weight - inertia = np.sum(sq_distances, dtype=np.float64) - return labels, inertia, centers, n_iter + n_samples = X.shape[0] + + centers_old = np.zeros_like(centers) + 
center_half_distances = euclidean_distances(centers_old) / 2 + distance_next_center = np.zeros(n_clusters, dtype=X.dtype) + upper_bounds = np.zeros(n_samples, dtype=X.dtype) + lower_bounds = np.zeros((n_samples, n_clusters), dtype=X.dtype) + labels = np.full(n_samples, -1, dtype=np.int32) + weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) + + _init_bounds(X, centers, center_half_distances, + labels, upper_bounds, lower_bounds) + + for i in range(max_iter): + _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, + weight_in_clusters, center_half_distances, + distance_next_center, upper_bounds, + lower_bounds, labels, n_jobs) + + if verbose: + inertia = _inertia_dense(X, sample_weight, centers_old, labels) + print("Iteration {0}, inertia {1}" .format(i, inertia)) + + center_shift_tot = squared_norm(centers - centers_old) + if center_shift_tot <= tol: + if verbose: + print("Converged at iteration {0}: " + "center shift {1} within tolerance {2}" + .format(i, center_shift_tot, tol)) + break + + if center_shift_tot > 0: + # rerun E-step in case of non-convergence so that predicted labels + # match cluster centers + _elkan_iter_chunked_dense(X, sample_weight, centers, centers, + weight_in_clusters, center_half_distances, + distance_next_center, upper_bounds, + lower_bounds, labels, n_jobs, + update_centers=False) + + inertia = _inertia_dense(X, sample_weight, centers, labels) + + return labels, inertia, centers, i + 1 def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, - random_state=None, tol=1e-4, - precompute_distances=True): + random_state=None, tol=1e-4, n_jobs=-1): """A single run of k-means, assumes preparation completed prior. Parameters @@ -496,14 +505,14 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, x_squared_norms : array Precomputed x_squared_norms. 
- precompute_distances : boolean, default: True - Precompute distances (faster but takes more memory). - random_state : int, RandomState instance or None (default) Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. See :term:`Glossary `. + n_jobs : int + The number of threads to be used. If -1, will use as many as possible. + Returns ------- centroid : float ndarray with shape (k, n_features) @@ -524,119 +533,60 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, sample_weight = _check_sample_weight(X, sample_weight) - best_labels, best_inertia, best_centers = None, None, None # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, x_squared_norms=x_squared_norms) if verbose: print("Initialization complete") - # Allocate memory to store the distances for each sample to its - # closer center for reallocation in case of ties - distances = np.zeros(shape=(X.shape[0],), dtype=X.dtype) + centers_old = np.zeros_like(centers) + centers_squared_norms = np.zeros(n_clusters, dtype=X.dtype) + labels = np.full(X.shape[0], -1, dtype=np.int32) + weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) + + if sp.issparse(X): + lloyd_iter = _lloyd_iter_chunked_sparse + _inertia = _inertia_sparse + else: + lloyd_iter = _lloyd_iter_chunked_dense + _inertia = _inertia_dense - # iterations for i in range(max_iter): - centers_old = centers.copy() - # labels assignment is also called the E-step of EM - labels, inertia = \ - _labels_inertia(X, sample_weight, x_squared_norms, centers, - precompute_distances=precompute_distances, - distances=distances) - - # computation of the means is also called the M-step of EM - if sp.issparse(X): - centers = _k_means._centers_sparse(X, sample_weight, labels, - n_clusters, distances) - else: - centers = _k_means._centers_dense(X, sample_weight, labels, - n_clusters, distances) + lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, 
+ centers_squared_norms, weight_in_clusters, labels, n_jobs) if verbose: - print("Iteration %2d, inertia %.3f" % (i, inertia)) + inertia = _inertia(X, sample_weight, centers_old, labels) + print("Iteration {0}, inertia {1}" .format(i, inertia)) - if best_inertia is None or inertia < best_inertia: - best_labels = labels.copy() - best_centers = centers.copy() - best_inertia = inertia - - center_shift_total = squared_norm(centers_old - centers) - if center_shift_total <= tol: + center_shift = squared_norm(centers - centers_old) + if center_shift <= tol: if verbose: - print("Converged at iteration %d: " - "center shift %e within tolerance %e" - % (i, center_shift_total, tol)) + print("Converged at iteration {0}: " + "center shift {1} within tolerance {2}" + .format(i, center_shift, tol)) break - if center_shift_total > 0: + if center_shift > 0: # rerun E-step in case of non-convergence so that predicted labels # match cluster centers - best_labels, best_inertia = \ - _labels_inertia(X, sample_weight, x_squared_norms, best_centers, - precompute_distances=precompute_distances, - distances=distances) - - return best_labels, best_inertia, best_centers, i + 1 - - -def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms, - centers, distances): - """Compute labels and inertia using a full distance matrix. - - This will overwrite the 'distances' array in-place. - - Parameters - ---------- - X : numpy array, shape (n_sample, n_features) - Input data. - - sample_weight : array-like, shape (n_samples,) - The weights for each observation in X. - - x_squared_norms : numpy array, shape (n_samples,) - Precomputed squared norms of X. - - centers : numpy array, shape (n_clusters, n_features) - Cluster centers which data is assigned to. - - distances : numpy array, shape (n_samples,) - Pre-allocated array in which distances are stored. 
+ lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, + centers_squared_norms, weight_in_clusters, labels, n_jobs, + update_centers=False) - Returns - ------- - labels : numpy array, dtype=np.int, shape (n_samples,) - Indices of clusters that samples are assigned to. + inertia = _inertia(X, sample_weight, centers, labels) - inertia : float - Sum of squared distances of samples to their closest cluster center. + return labels, inertia, centers, i + 1 - """ - n_samples = X.shape[0] - # Breakup nearest neighbor distance computation into batches to prevent - # memory blowup in the case of a large number of samples and clusters. - # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs. - labels, mindist = pairwise_distances_argmin_min( - X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True}) - # cython k-means code assumes int32 inputs - labels = labels.astype(np.int32) - if n_samples == distances.shape[0]: - # distances will be changed in-place - distances[:] = mindist - inertia = (mindist * sample_weight).sum() - return labels, inertia - - -def _labels_inertia(X, sample_weight, x_squared_norms, centers, - precompute_distances=True, distances=None): +def _labels_inertia(X, sample_weight, x_squared_norms, centers): """E step of the K-means EM algorithm. Compute the labels and the inertia of the given samples and centers. - This will compute the distances in-place. Parameters ---------- - X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features) + X : float array-like or CSR sparse matrix, shape (n_samples, n_features) The input samples to assign to the labels. sample_weight : array-like, shape (n_samples,) @@ -646,19 +596,12 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, Precomputed squared euclidean norm of each data point, to speed up computations. - centers : float array, shape (k, n_features) + centers : float array, shape (n_clusters, n_features) The cluster centers. 
- precompute_distances : boolean, default: True - Precompute distances (faster but takes more memory). - - distances : float array, shape (n_samples,) - Pre-allocated array to be filled in with each sample's distance - to the closest center. - Returns ------- - labels : int array of shape(n) + labels : int array, shape (n_samples,) The resulting assignment inertia : float @@ -666,24 +609,23 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, """ n_samples = X.shape[0] sample_weight = _check_sample_weight(X, sample_weight) - # set the default value of centers to -1 to be able to detect any anomaly - # easily - labels = np.full(n_samples, -1, np.int32) - if distances is None: - distances = np.zeros(shape=(0,), dtype=X.dtype) - # distances will be changed in-place + labels = np.full(n_samples, -1, dtype=np.int32) + centers_squared_norms = np.zeros(centers.shape[0], dtype=centers.dtype) + weight_in_clusters = np.zeros_like(centers_squared_norms) + if sp.issparse(X): - inertia = _k_means._assign_labels_csr( - X, sample_weight, x_squared_norms, centers, labels, - distances=distances) + labels_centers = _lloyd_iter_chunked_sparse + _inertia = _inertia_sparse else: - if precompute_distances: - return _labels_inertia_precompute_dense(X, sample_weight, - x_squared_norms, centers, - distances) - inertia = _k_means._assign_labels_array( - X, sample_weight, x_squared_norms, centers, labels, - distances=distances) + labels_centers = _lloyd_iter_chunked_dense + _inertia = _inertia_dense + + labels_centers(X, sample_weight, x_squared_norms, centers, + centers, centers_squared_norms, weight_in_clusters, + labels, update_centers=False) + + inertia = _inertia(X, sample_weight, centers, labels) + return labels, inertia @@ -814,6 +756,9 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): True : always precompute distances False : never precompute distances + .. 
deprecated:: 0.21 + 'precompute_distances' was deprecated in version 0.21 and will be + removed in 0.23. verbose : int, default 0 Verbosity mode. @@ -868,8 +813,8 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): >>> from sklearn.cluster import KMeans >>> import numpy as np >>> X = np.array([[1, 2], [1, 4], [1, 0], - ... [10, 2], [10, 4], [10, 0]]) - >>> kmeans = KMeans(n_clusters=2, random_state=0).fit(X) + ... [4, 2], [4, 4], [4, 0]]) + >>> kmeans = KMeans(n_clusters=2, random_state=1234).fit(X) >>> kmeans.labels_ array([1, 1, 1, 0, 0, 0], dtype=int32) >>> kmeans.predict([[0, 0], [12, 3]]) @@ -912,7 +857,7 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): """ def __init__(self, n_clusters=8, init='k-means++', n_init=10, - max_iter=300, tol=1e-4, precompute_distances='auto', + max_iter=300, tol=1e-4, precompute_distances='not-used', verbose=0, random_state=None, copy_x=True, n_jobs=None, algorithm='auto'): @@ -957,6 +902,11 @@ def fit(self, X, y=None, sample_weight=None): are assigned equal weight (default: None) """ + if self.precompute_distances != 'not-used': + warnings.warn("'precompute_distances' was deprecated in version" + "0.21 and will be removed in 0.23.", + DeprecationWarning) + random_state = check_random_state(self.random_state) self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ @@ -1074,6 +1024,7 @@ def predict(self, X, sample_weight=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) + return _labels_inertia(X, sample_weight, x_squared_norms, self.cluster_centers_)[0] @@ -1101,6 +1052,7 @@ def score(self, X, y=None, sample_weight=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) + return -_labels_inertia(X, sample_weight, x_squared_norms, self.cluster_centers_)[1] @@ -1173,8 +1125,7 @@ def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums, """ # Perform label assignment to nearest centers nearest_center, inertia = 
_labels_inertia(X, sample_weight, - x_squared_norms, centers, - distances=distances) + x_squared_norms, centers) if random_reassign and reassignment_ratio > 0: random_state = check_random_state(random_state) @@ -1208,7 +1159,7 @@ def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums, # implementation for the sparse CSR representation completely written in # cython if sp.issparse(X): - return inertia, _k_means._mini_batch_update_csr( + return inertia, _mini_batch_update_csr( X, sample_weight, x_squared_norms, centers, weight_sums, nearest_center, old_center_buffer, compute_squared_diff) @@ -1424,8 +1375,8 @@ class MiniBatchKMeans(KMeans): >>> kmeans = kmeans.partial_fit(X[0:6,:]) >>> kmeans = kmeans.partial_fit(X[6:12,:]) >>> kmeans.cluster_centers_ - array([[1, 1], - [3, 4]]) + array([[2. , 1. ], + [3.5, 4.5]]) >>> kmeans.predict([[0, 0], [4, 4]]) array([0, 1], dtype=int32) >>> # fit on the whole data @@ -1669,7 +1620,8 @@ def partial_fit(self, X, y=None, sample_weight=None): """ - X = check_array(X, accept_sparse="csr", order="C") + X = check_array(X, accept_sparse="csr", order='C', + dtype=[np.float64, np.float32]) n_samples, n_features = X.shape if hasattr(self.init, '__array__'): self.init = np.ascontiguousarray(self.init, dtype=X.dtype) diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py index 99c4dcd6177b0..75b3e355138e4 100644 --- a/sklearn/cluster/setup.py +++ b/sklearn/cluster/setup.py @@ -29,26 +29,36 @@ def configuration(parent_package='', top_path=None): language="c++", include_dirs=[numpy.get_include()], libraries=libraries) - config.add_extension('_k_means_elkan', - sources=['_k_means_elkan.pyx'], + + config.add_extension('_k_means', + sources=['_k_means.pyx'], include_dirs=[numpy.get_include()], libraries=libraries) - config.add_extension('_k_means', + config.add_extension('_k_means_lloyd', libraries=cblas_libs, - sources=['_k_means.pyx'], + sources=['_k_means_lloyd.pyx'], include_dirs=[join('..', 'src', 'cblas'), 
numpy.get_include(), blas_info.pop('include_dirs', [])], + extra_link_args=['-fopenmp'], extra_compile_args=blas_info.pop( - 'extra_compile_args', []), + 'extra_compile_args', []) + ['-fopenmp'], **blas_info ) + config.add_extension('_k_means_elkan', + sources=['_k_means_elkan.pyx'], + include_dirs=[numpy.get_include()], + libraries=libraries, + extra_link_args=['-fopenmp'], + extra_compile_args=['-fopenmp']) + config.add_subpackage('tests') return config + if __name__ == '__main__': from numpy.distutils.core import setup setup(**configuration(top_path='').todict()) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 37571d427002b..5029609684487 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -234,6 +234,16 @@ def test_k_means_new_centers(): np.testing.assert_array_equal(this_labels, labels) +def test_k_means_precompute_distances_deprecated(): + # check that the deprecation warning is raised for precompute_distances + with pytest.warns(DeprecationWarning, match='precompute_distances'): + km = KMeans(precompute_distances='auto') + km.fit(X) + + with pytest.warns(DeprecationWarning, match='precompute_distances'): + k_means(X, n_clusters, precompute_distances='auto') + + @if_safe_multiprocessing_with_blas def test_k_means_plus_plus_init_2_jobs(): km = KMeans(init="k-means++", n_clusters=n_clusters, n_jobs=2, @@ -241,25 +251,6 @@ def test_k_means_plus_plus_init_2_jobs(): _check_fitted_model(km) -def test_k_means_precompute_distances_flag(): - # check that a warning is raised if the precompute_distances flag is not - # supported - km = KMeans(precompute_distances="wrong") - assert_raises(ValueError, km.fit, X) - - -def test_k_means_plus_plus_init_not_precomputed(): - km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42, - precompute_distances=False).fit(X) - _check_fitted_model(km) - - -def test_k_means_random_init_not_precomputed(): - km = 
KMeans(init="random", n_clusters=n_clusters, random_state=42, - precompute_distances=False).fit(X) - _check_fitted_model(km) - - @pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse']) @pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()]) def test_k_means_init(data, init): @@ -310,8 +301,7 @@ def test_k_means_fortran_aligned_data(): X = np.asfortranarray([[0, 0], [0, 1], [0, 1]]) centers = np.array([[0, 0], [0, 1]]) labels = np.array([0, 1, 1]) - km = KMeans(n_init=1, init=centers, precompute_distances=False, - random_state=42, n_clusters=2) + km = KMeans(n_init=1, init=centers, random_state=42, n_clusters=2) km.fit(X) assert_array_almost_equal(km.cluster_centers_, centers) assert_array_equal(km.labels_, labels) diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx index 4e13fce315c57..ba04bd54aba42 100644 --- a/sklearn/utils/sparsefuncs_fast.pyx +++ b/sklearn/utils/sparsefuncs_fast.pyx @@ -25,35 +25,36 @@ ctypedef fused integral: ctypedef np.float64_t DOUBLE + def csr_row_norms(X): """L2 norm of each row in CSR matrix X.""" if X.dtype not in [np.float32, np.float64]: X = X.astype(np.float64) - return _csr_row_norms(X.data, X.shape, X.indices, X.indptr) + + norms = np.zeros(X.shape[0], dtype=X.data.dtype) + _csr_row_norms(X.data, X.shape, X.indices, X.indptr, norms) + + return norms def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data, shape, np.ndarray[integral, ndim=1, mode="c"] X_indices, - np.ndarray[integral, ndim=1, mode="c"] X_indptr): + np.ndarray[integral, ndim=1, mode="c"] X_indptr, + floating[::1] norms): cdef: unsigned long long n_samples = shape[0] - unsigned long long n_features = shape[1] - np.ndarray[DOUBLE, ndim=1, mode="c"] norms - - np.npy_intp i, j + + unsigned long long i + integral j double sum_ - norms = np.zeros(n_samples, dtype=np.float64) - for i in range(n_samples): sum_ = 0.0 for j in range(X_indptr[i], X_indptr[i + 1]): sum_ += X_data[j] * X_data[j] norms[i] = 
sum_ - return norms - def csr_mean_variance_axis0(X): """Compute mean and variance along axis 0 on a CSR matrix diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py index 6a4596634f28d..781184a6fc173 100644 --- a/sklearn/utils/tests/test_sparsefuncs.py +++ b/sklearn/utils/tests/test_sparsefuncs.py @@ -18,7 +18,8 @@ count_nonzero, csc_median_axis_0) from sklearn.utils.sparsefuncs_fast import (assign_rows_csr, inplace_csr_row_normalize_l1, - inplace_csr_row_normalize_l2) + inplace_csr_row_normalize_l2, + csr_row_norms) from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_allclose @@ -522,3 +523,16 @@ def test_inplace_normalize(): if inplace_csr_row_normalize is inplace_csr_row_normalize_l2: X_csr.data **= 2 assert_array_almost_equal(np.abs(X_csr).sum(axis=1), ones) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_csr_row_norms(dtype): + # checks that csr_row_norms returns the same output as + # scipy.sparse.linalg.norm, and that the dype is the same X's. 
+ X = sp.random(100, 10, format='csr', dtype=dtype) + + scipy_norms = sp.linalg.norm(X, axis=1)**2 + norms = csr_row_norms(X) + + assert norms.dtype.type is dtype + assert_array_almost_equal(norms, scipy_norms) From 7966dd0e74a8aa6003966f18b63fe08866d9b858 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 12:11:26 +0200 Subject: [PATCH 002/163] elkan center_half_distance init to 0 & out center_shift --- sklearn/cluster/_k_means_elkan.pyx | 13 ++++++------- sklearn/cluster/k_means_.py | 23 ++++++++++++----------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index ce41c2534e227..1d5abdb6faa83 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -114,12 +114,13 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers_old, floating[:, ::1] centers_new, - floating[::1] weight_in_clusters, + floating[::1] weight_in_clusters, + int[::1] labels, + floating[::1] center_shift, floating[:, ::1] center_half_distances, floating[::1] distance_next_center, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, - int[::1] labels, int n_jobs = -1, bint update_centers = True): """Single interation of K-means elkan algorithm @@ -147,6 +148,9 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, Placeholder for the sums of the weights of every observation assigned to each center. + labels : int array-like, shape (n_samples,) + labels assignment. + center_half_distances : {float32, float64} array-like, \ shape (n_clusters, n_clusters) Half pairwise distances between centers. @@ -162,9 +166,6 @@ shape (n_clusters, n_clusters) Lower bound for the distance between each sample and each center, updated inplace. - labels : int array-like, shape (n_samples,) - labels assignment. 
- n_jobs : int The number of threads to be used by openmp. If -1, openmp will use as many as possible. @@ -195,8 +196,6 @@ shape (n_clusters, n_clusters) floating *centers_new_chunk floating *weight_in_clusters_chunk - floating[::1] center_shift = np.zeros(n_clusters, dtype=X.dtype) - # count remainder chunk in total number of chunks n_chunks += n_samples != n_chunks * n_samples_chunk diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index f6a71671f908e..dc49ce02af780 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -418,27 +418,28 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, n_samples = X.shape[0] centers_old = np.zeros_like(centers) - center_half_distances = euclidean_distances(centers_old) / 2 + labels = np.full(n_samples, -1, dtype=np.int32) + weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) + center_shift = np.zeros(n_clusters, dtype=X.dtype) + center_half_distances = np.zeros((n_clusters, n_clusters), dtype=X.dtype) distance_next_center = np.zeros(n_clusters, dtype=X.dtype) upper_bounds = np.zeros(n_samples, dtype=X.dtype) lower_bounds = np.zeros((n_samples, n_clusters), dtype=X.dtype) - labels = np.full(n_samples, -1, dtype=np.int32) - weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) _init_bounds(X, centers, center_half_distances, labels, upper_bounds, lower_bounds) for i in range(max_iter): _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, - weight_in_clusters, center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, n_jobs) + weight_in_clusters, labels, center_shift, + center_half_distances, distance_next_center, + upper_bounds, lower_bounds, n_jobs) if verbose: inertia = _inertia_dense(X, sample_weight, centers_old, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) - center_shift_tot = squared_norm(centers - centers_old) + center_shift_tot = (center_shift**2).sum() if center_shift_tot <= tol: if verbose: 
print("Converged at iteration {0}: " @@ -449,10 +450,10 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, if center_shift_tot > 0: # rerun E-step in case of non-convergence so that predicted labels # match cluster centers - _elkan_iter_chunked_dense(X, sample_weight, centers, centers, - weight_in_clusters, center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, n_jobs, + _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, + weight_in_clusters, labels, center_shift, + center_half_distances, distance_next_center, + upper_bounds, lower_bounds, n_jobs, update_centers=False) inertia = _inertia_dense(X, sample_weight, centers, labels) From 97fcf1f7555de54fce30a840bc5bc23eb1fbb75c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 12:31:35 +0200 Subject: [PATCH 003/163] out center_shift & numpy computations on pairwise_distances --- sklearn/cluster/_k_means_elkan.pyx | 11 ++--------- sklearn/cluster/k_means_.py | 27 +++++++++++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 1d5abdb6faa83..4f08ec6065891 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -93,8 +93,6 @@ cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, floating min_dist, dist int best_cluster, i, j - center_half_distances = euclidean_distances(np.asarray(centers)) / 2 - for i in range(n_samples): best_cluster = 0 min_dist = euclidean_dist(&X[i, 0], ¢ers[0, 0], n_features) @@ -115,12 +113,12 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] weight_in_clusters, - int[::1] labels, - floating[::1] center_shift, floating[:, ::1] center_half_distances, floating[::1] distance_next_center, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, + int[::1] 
labels, + floating[::1] center_shift, int n_jobs = -1, bint update_centers = True): """Single interation of K-means elkan algorithm @@ -205,9 +203,6 @@ shape (n_clusters, n_clusters) memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - # compute pairwise distances between centers and get next closest center - distance_next_center = np.partition(np.asarray(center_half_distances), kth=1, axis=0)[1] - # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() @@ -281,8 +276,6 @@ shape (n_clusters, n_clusters) if lower_bounds[i, j] < 0: lower_bounds[i, j] = 0 - center_half_distances = euclidean_distances(np.asarray(centers_old)) / 2 - cdef void _update_chunk(floating *X, floating *sample_weight, diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index dc49ce02af780..9cb35a9ec42d1 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -418,22 +418,29 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, n_samples = X.shape[0] centers_old = np.zeros_like(centers) - labels = np.full(n_samples, -1, dtype=np.int32) weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) - center_shift = np.zeros(n_clusters, dtype=X.dtype) - center_half_distances = np.zeros((n_clusters, n_clusters), dtype=X.dtype) + labels = np.full(n_samples, -1, dtype=np.int32) + center_half_distances = euclidean_distances(centers) / 2 distance_next_center = np.zeros(n_clusters, dtype=X.dtype) upper_bounds = np.zeros(n_samples, dtype=X.dtype) lower_bounds = np.zeros((n_samples, n_clusters), dtype=X.dtype) + center_shift = np.zeros(n_clusters, dtype=X.dtype) _init_bounds(X, centers, center_half_distances, labels, upper_bounds, lower_bounds) for i in range(max_iter): + # compute the closest other center of each center + distance_next_center = np.partition(np.asarray(center_half_distances), + kth=1, axis=0)[1] 
+ _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, - weight_in_clusters, labels, center_shift, - center_half_distances, distance_next_center, - upper_bounds, lower_bounds, n_jobs) + weight_in_clusters, center_half_distances, + distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs) + + # compute new pairwise distances between centers for next iterations + center_half_distances = euclidean_distances(centers) / 2 if verbose: inertia = _inertia_dense(X, sample_weight, centers_old, labels) @@ -450,10 +457,10 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, if center_shift_tot > 0: # rerun E-step in case of non-convergence so that predicted labels # match cluster centers - _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, - weight_in_clusters, labels, center_shift, - center_half_distances, distance_next_center, - upper_bounds, lower_bounds, n_jobs, + _elkan_iter_chunked_dense(X, sample_weight, centers, centers, + weight_in_clusters, center_half_distances, + distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs, update_centers=False) inertia = _inertia_dense(X, sample_weight, centers, labels) From 78a167d784c876d6092cf3e6f8caff7361feaae8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 12:45:20 +0200 Subject: [PATCH 004/163] comment --- sklearn/cluster/_k_means_elkan.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 4f08ec6065891..21c1060eb7d51 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -179,8 +179,8 @@ shape (n_clusters, n_clusters) int n_features = X.shape[1] int n_clusters = centers_new.shape[0] - # hard-coded number of samples per chunk. Appeared to be close to - # optimal in all situations. + # hard-coded number of samples per chunk. 
Splitting in chunks is + # necessary to get parallelism. Chunk size chosed to be same as lloyd's int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk From 35fd78e2abe9439e4c0ebf3831c1cce900132ee9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 15:17:46 +0200 Subject: [PATCH 005/163] error message minibatchkmeans partial_fit different number of features --- sklearn/cluster/k_means_.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 9cb35a9ec42d1..57983e8c4a6ed 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -1663,6 +1663,13 @@ def partial_fit(self, X, y=None, sample_weight=None): 10 * (1 + self.counts_.min())) == 0 distances = np.zeros(X.shape[0], dtype=X.dtype) + # Raise error if partial_fit called on data with different number + # of features. + if X.shape[1] != self.cluster_centers_.shape[1]: + raise ValueError( + "Number of features %d does not match previous " + "data %d." 
% (X.shape[1], self.cluster_centers_.shape[1])) + _mini_batch_step(X, sample_weight, x_squared_norms, self.cluster_centers_, self.counts_, np.zeros(0, dtype=X.dtype), 0, From 6dae806e05636119d9e11a8da2b9846324dca7f2 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 15:19:57 +0200 Subject: [PATCH 006/163] drop python 2 CI --- .circleci/config.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 89cc103ec6301..f242e4a516edb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -103,6 +103,7 @@ jobs: command: | if [[ "${CIRCLE_BRANCH}" =~ ^master$|^[0-9]+\.[0-9]+\.X$ ]]; then bash build_tools/circle/push_doc.sh doc/_build/html/stable +<<<<<<< 6c0faf614b525bad520269e28ec684b44c00c22a fi workflows: @@ -130,3 +131,6 @@ workflows: - master jobs: - pypy3 +======= + fi +>>>>>>> drop python 2 CI From f5c0aa1628d6a0f2ca10df4d969556a7c4bf1227 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 16:50:28 +0200 Subject: [PATCH 007/163] refactor center_shift computation --- sklearn/cluster/_k_means.pyx | 28 ++++++++++++++++++++++++++++ sklearn/cluster/_k_means_elkan.pyx | 22 +++++----------------- sklearn/cluster/_k_means_lloyd.pyx | 22 +++++++++------------- sklearn/cluster/k_means_.py | 18 ++++++++++-------- 4 files changed, 52 insertions(+), 38 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 382efa6969666..74bda4f00f47f 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -12,6 +12,7 @@ import numpy as np cimport numpy as np cimport cython from cython cimport floating +from libc.math cimport sqrt np.import_array() @@ -182,6 +183,33 @@ cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, weight_in_clusters[old_cluster_id] -= weight +cpdef void _mean_and_center_shift(floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] weight_in_clusters, + floating[::1] 
center_shift): + cdef: + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + + int j, k + floating alpha, tmp, x + + # average new centers wrt sample weights + for j in xrange(n_clusters): + if weight_in_clusters[j] > 0: + alpha = 1.0 / weight_in_clusters[j] + for k in xrange(n_features): + centers_new[j, k] *= alpha + + # compute shift distance between old and new centers + for j in range(n_clusters): + tmp = 0 + for k in range(n_features): + x = centers_new[j, k] - centers_old[j, k] + tmp += x * x + center_shift[j] = sqrt(tmp) + + def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight, np.ndarray[floating, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 21c1060eb7d51..077267a017688 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -15,7 +15,7 @@ from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from ..metrics import euclidean_distances -from ._k_means import _relocate_empty_clusters_dense +from ._k_means import _relocate_empty_clusters_dense, _mean_and_center_shift np.import_array() @@ -252,22 +252,10 @@ shape (n_clusters, n_clusters) _relocate_empty_clusters_dense(X, sample_weight, centers_new, weight_in_clusters, labels) - # average new centers wrt sample weights - for j in xrange(n_clusters): - if weight_in_clusters[j] > 0: - alpha = 1.0 / weight_in_clusters[j] - for k in xrange(n_features): - centers_new[j, k] *= alpha - - # compute shift distance between old and new centers - for j in range(n_clusters): - tmp = 0 - for k in range(n_features): - x = centers_new[j, k] - centers_old[j, k] - tmp += x * x - center_shift[j] = sqrt(tmp) - - # update lower and upper bounds accordingly + _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, + center_shift) + + # update lower and upper bounds for i in range(n_samples): upper_bounds[i] 
+= center_shift[labels[i]] diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index c1c1f980be35b..94e84f04487db 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -9,11 +9,13 @@ cimport openmp from cython cimport floating from cython.parallel import prange, parallel from scipy.linalg.cython_blas cimport sgemm, dgemm +from libc.math cimport sqrt from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from ._k_means import (_relocate_empty_clusters_dense, - _relocate_empty_clusters_sparse) + _relocate_empty_clusters_sparse, + _mean_and_center_shift) np.import_array() @@ -41,6 +43,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] centers_squared_norms, floating[::1] weight_in_clusters, int[::1] labels, + floating[::1] center_shift, int n_jobs = -1, bint update_centers = True): """Single interation of K-means lloyd algorithm @@ -179,12 +182,8 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, _relocate_empty_clusters_dense(X, sample_weight, centers_new, weight_in_clusters, labels) - # average new centers wrt sample weights - for j in xrange(n_clusters): - if weight_in_clusters[j] > 0: - alpha = 1.0 / weight_in_clusters[j] - for k in xrange(n_features): - centers_new[j, k] *= alpha + _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, + center_shift) cdef void _update_chunk_dense(floating *X, @@ -253,6 +252,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[::1] centers_squared_norms, floating[::1] weight_in_clusters, int[::1] labels, + floating[::1] center_shift, int n_jobs = -1, bint update_centers = True): """Single interation of K-means lloyd algorithm @@ -391,12 +391,8 @@ cpdef void _lloyd_iter_chunked_sparse(X, sample_weight, centers_new, weight_in_clusters, labels) - # average new centers wrt sample weights - for j in xrange(n_clusters): - if weight_in_clusters[j] > 0: 
- alpha = 1.0 / weight_in_clusters[j] - for k in xrange(n_features): - centers_new[j, k] *= alpha + _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, + center_shift) cdef void _update_chunk_sparse(floating *X_data, diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 57983e8c4a6ed..48e6c4a171978 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -19,7 +19,7 @@ from ..base import BaseEstimator, ClusterMixin, TransformerMixin from ..metrics.pairwise import euclidean_distances -from ..utils.extmath import row_norms, squared_norm, stable_cumsum +from ..utils.extmath import row_norms, stable_cumsum from ..utils.sparsefuncs_fast import assign_rows_csr from ..utils.sparsefuncs import mean_variance_axis from ..utils.validation import _num_samples @@ -551,6 +551,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, centers_squared_norms = np.zeros(n_clusters, dtype=X.dtype) labels = np.full(X.shape[0], -1, dtype=np.int32) weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) + center_shift = np.zeros(n_clusters, dtype=X.dtype) if sp.issparse(X): lloyd_iter = _lloyd_iter_chunked_sparse @@ -561,26 +562,27 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, - centers_squared_norms, weight_in_clusters, labels, n_jobs) + centers_squared_norms, weight_in_clusters, labels, + center_shift, n_jobs) if verbose: inertia = _inertia(X, sample_weight, centers_old, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) - center_shift = squared_norm(centers - centers_old) - if center_shift <= tol: + center_shift_tot = (center_shift**2).sum() + if center_shift_tot <= tol: if verbose: print("Converged at iteration {0}: " "center shift {1} within tolerance {2}" - .format(i, center_shift, tol)) + .format(i, center_shift_tot, tol)) break - if center_shift > 0: + if center_shift_tot 
> 0: # rerun E-step in case of non-convergence so that predicted labels # match cluster centers lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, - centers_squared_norms, weight_in_clusters, labels, n_jobs, - update_centers=False) + centers_squared_norms, weight_in_clusters, labels, + center_shift, n_jobs, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) From a1c1facbddf48b05ef7efc28d3cfdbbd28ebfd05 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 17:00:31 +0200 Subject: [PATCH 008/163] docstring --- sklearn/cluster/_k_means.pyx | 1 + sklearn/cluster/_k_means_elkan.pyx | 9 ++++++--- sklearn/cluster/_k_means_lloyd.pyx | 6 ++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 74bda4f00f47f..9ec65bf98f5b0 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -187,6 +187,7 @@ cpdef void _mean_and_center_shift(floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] weight_in_clusters, floating[::1] center_shift): + """Average new centers wrt weights and compute center shift""" cdef: int n_clusters = centers_old.shape[0] int n_features = centers_old.shape[1] diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 077267a017688..150c8d93acb42 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -146,9 +146,6 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, Placeholder for the sums of the weights of every observation assigned to each center. - labels : int array-like, shape (n_samples,) - labels assignment. - center_half_distances : {float32, float64} array-like, \ shape (n_clusters, n_clusters) Half pairwise distances between centers. @@ -164,6 +161,12 @@ shape (n_clusters, n_clusters) Lower bound for the distance between each sample and each center, updated inplace. 
+ labels : int array-like, shape (n_samples,) + labels assignment. + + center_shift : {float32, float64} array-like, shape (n_clusters,) + Distance between old and new centers. + n_jobs : int The number of threads to be used by openmp. If -1, openmp will use as many as possible. diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 94e84f04487db..359e203c4c4dd 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -79,6 +79,9 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, labels : int array-like, shape (n_samples,) labels assignment. + + center_shift : {float32, float64} array-like, shape (n_clusters,) + Distance between old and new centers. n_jobs : int The number of threads to be used by openmp. If -1, openmp will use as @@ -288,6 +291,9 @@ cpdef void _lloyd_iter_chunked_sparse(X, labels : int array-like, shape (n_samples,) labels assignment. + + center_shift : {float32, float64} array-like, shape (n_clusters,) + Distance between old and new centers. n_jobs : int The number of threads to be used by openmp. 
If -1, openmp will use as From 8df3b1e628a85bc815faaf4721cf401327855c19 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 19:02:24 +0200 Subject: [PATCH 009/163] fix center_shift --- sklearn/cluster/k_means_.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 48e6c4a171978..311494b9a66e3 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -622,6 +622,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers): labels = np.full(n_samples, -1, dtype=np.int32) centers_squared_norms = np.zeros(centers.shape[0], dtype=centers.dtype) weight_in_clusters = np.zeros_like(centers_squared_norms) + center_shift = np.zeros_like(centers_squared_norms) if sp.issparse(X): labels_centers = _lloyd_iter_chunked_sparse @@ -632,7 +633,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers): labels_centers(X, sample_weight, x_squared_norms, centers, centers, centers_squared_norms, weight_in_clusters, - labels, update_centers=False) + labels, center_shift, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) From 8f111c7deaca6b3d71a2017e3bfd9398c6b4d012 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 22 Oct 2018 19:02:39 +0200 Subject: [PATCH 010/163] update tests --- sklearn/cluster/tests/test_k_means.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 5029609684487..8d213f4310acb 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -234,16 +234,6 @@ def test_k_means_new_centers(): np.testing.assert_array_equal(this_labels, labels) -def test_k_means_precompute_distances_deprecated(): - # check that the deprecation warning is raised for precompute_distances - with pytest.warns(DeprecationWarning, match='precompute_distances'): - km 
= KMeans(precompute_distances='auto') - km.fit(X) - - with pytest.warns(DeprecationWarning, match='precompute_distances'): - k_means(X, n_clusters, precompute_distances='auto') - - @if_safe_multiprocessing_with_blas def test_k_means_plus_plus_init_2_jobs(): km = KMeans(init="k-means++", n_clusters=n_clusters, n_jobs=2, @@ -301,7 +291,8 @@ def test_k_means_fortran_aligned_data(): X = np.asfortranarray([[0, 0], [0, 1], [0, 1]]) centers = np.array([[0, 0], [0, 1]]) labels = np.array([0, 1, 1]) - km = KMeans(n_init=1, init=centers, random_state=42, n_clusters=2) + km = KMeans(n_init=1, init=centers, precompute_distances=False, + random_state=42, n_clusters=2) km.fit(X) assert_array_almost_equal(km.cluster_centers_, centers) assert_array_equal(km.labels_, labels) From e8be35428d60ae2aa8109b44ecd52106068675ff Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 25 Oct 2018 12:35:51 +0200 Subject: [PATCH 011/163] range consistency --- sklearn/cluster/_k_means.pyx | 30 +++++----- sklearn/cluster/_k_means_elkan.pyx | 4 +- sklearn/cluster/_k_means_lloyd.pyx | 89 ++++++++++++------------------ 3 files changed, 51 insertions(+), 72 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 9ec65bf98f5b0..7e619532daceb 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -38,10 +38,10 @@ cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating inertia = 0.0 - for i in xrange(n_samples): + for i in range(n_samples): j = labels[i] sample_inertia = 0.0 - for k in xrange(n_features): + for k in range(n_features): tmp = X[i, k] - centers[j, k] sample_inertia += tmp * tmp inertia += sample_inertia * sample_weight[i] @@ -70,13 +70,13 @@ cpdef floating _inertia_sparse(X, floating inertia = 0.0 - for i in xrange(n_samples): + for i in range(n_samples): j = labels[i] sample_inertia = 0.0 row_ptr = X_indptr[i] nz_len = X_indptr[i + 1] - X_indptr[i] nz_ptr = 0 - for k in xrange(n_features): + 
for k in range(n_features): if nz_ptr < nz_len and k == X_indices[row_ptr + nz_ptr]: tmp = X_data[row_ptr + nz_ptr] - centers[j, k] nz_ptr += 1 @@ -112,7 +112,7 @@ cpdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] floating weight if n_empty > 0: - for idx in xrange(n_empty): + for idx in range(n_empty): new_cluster_id = empty_clusters[idx] @@ -121,7 +121,7 @@ cpdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] old_cluster_id = labels[far_idx] - for k in xrange(n_features): + for k in range(n_features): centers[new_cluster_id, k] = X[far_idx, k] * weight centers[old_cluster_id, k] -= X[far_idx, k] * weight @@ -138,8 +138,7 @@ cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int[::1] labels): """Relocate centers which have no sample assigned to them""" cdef: - int[::1] empty_clusters = \ - np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) int n_empty = empty_clusters.shape[0] if n_empty == 0: @@ -152,21 +151,20 @@ cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, floating[::1] distances = np.zeros(n_samples, dtype=X_data.base.dtype) - for i in xrange(n_samples): + for i in range(n_samples): j = labels[i] - for k in xrange(X_indptr[i], X_indptr[i + 1]): + for k in range(X_indptr[i], X_indptr[i + 1]): x = (X_data[k] - centers[j, X_indices[k]]) distances[i] += x * x cdef: - int[::1] far_from_centers = \ - np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx floating weight if n_empty > 0: - for idx in xrange(n_empty): + for idx in range(n_empty): new_cluster_id = empty_clusters[idx] @@ -175,7 +173,7 @@ cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, old_cluster_id = labels[far_idx] - for k in 
xrange(X_indptr[far_idx], X_indptr[far_idx + 1]): + for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): centers[new_cluster_id, X_indices[k]] += X_data[k] * weight centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight @@ -196,10 +194,10 @@ cpdef void _mean_and_center_shift(floating[:, ::1] centers_old, floating alpha, tmp, x # average new centers wrt sample weights - for j in xrange(n_clusters): + for j in range(n_clusters): if weight_in_clusters[j] > 0: alpha = 1.0 / weight_in_clusters[j] - for k in xrange(n_features): + for k in range(n_features): centers_new[j, k] *= alpha # compute shift distance between old and new centers diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 150c8d93acb42..5ce93567f8794 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -243,9 +243,9 @@ shape (n_clusters, n_clusters) # race conditions. if update_centers: with gil: - for j in xrange(n_clusters): + for j in range(n_clusters): weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in xrange(n_features): + for k in range(n_features): centers_new[j, k] += centers_new_chunk[j * n_features + k] free(weight_in_clusters_chunk) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 359e203c4c4dd..21f01e83a0e98 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -118,32 +118,24 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # re-initialize all arrays at each iteration memset(¢ers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in xrange(n_clusters): - for k in xrange(n_features): + for j in range(n_clusters): + for k in range(n_features): centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] if update_centers: - memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], - n_clusters * n_features * sizeof(floating)) - memset(¢ers_new[0, 0], 0, - n_clusters * n_features * 
sizeof(floating)) + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): - centers_new_chunk = \ - malloc(n_clusters * n_features * sizeof(floating)) - - weight_in_clusters_chunk = \ - malloc(n_clusters * sizeof(floating)) - - pairwise_distances_chunk = \ - malloc(n_samples_chunk * n_clusters * sizeof(floating)) - + # thread local buffers + centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) + weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) + pairwise_distances_chunk = malloc(n_samples_chunk * n_clusters * sizeof(floating)) # initialize local buffers - memset(centers_new_chunk, 0, - n_clusters * n_features * sizeof(floating)) + memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) for chunk_idx in prange(n_chunks): @@ -171,11 +163,10 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # race conditions. if update_centers: with gil: - for j in xrange(n_clusters): + for j in range(n_clusters): weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in xrange(n_features): - centers_new[j, k] += \ - centers_new_chunk[j * n_features + k] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] free(weight_in_clusters_chunk) free(centers_new_chunk) @@ -221,18 +212,18 @@ cdef void _update_chunk_dense(floating *X, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store # the - 2 X.C^T + ||C||² term since the argmin for a given sample only # depends on the centers. 
- for i in xrange(n_samples): - for j in xrange(n_clusters): + for i in range(n_samples): + for j in range(n_clusters): pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] xgemm(trans_centers, trans_data, &n_clusters, &n_samples, &n_features, &alpha, centers_old, &n_features, X, &n_features, &beta, pairwise_distances, &n_clusters) - for i in xrange(n_samples): + for i in range(n_samples): min_sq_dist = pairwise_distances[i * n_clusters] best_cluster = 0 - for j in xrange(n_clusters): + for j in range(n_clusters): sq_dist = pairwise_distances[i * n_clusters + j] if sq_dist < min_sq_dist: min_sq_dist = sq_dist @@ -242,9 +233,8 @@ cdef void _update_chunk_dense(floating *X, if update_centers: weight_in_clusters[best_cluster] += sample_weight[i] - for k in xrange(n_features): - centers_new[best_cluster * n_features + k] += \ - X[i * n_features + k] * sample_weight[i] + for k in range(n_features): + centers_new[best_cluster * n_features + k] += X[i * n_features + k] * sample_weight[i] cpdef void _lloyd_iter_chunked_sparse(X, @@ -332,29 +322,23 @@ cpdef void _lloyd_iter_chunked_sparse(X, # re-initialize all arrays at each iteration memset(¢ers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in xrange(n_clusters): - for k in xrange(n_features): + for j in range(n_clusters): + for k in range(n_features): centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] if update_centers: - memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], - n_clusters * n_features * sizeof(floating)) - memset(¢ers_new[0, 0], 0, - n_clusters * n_features * sizeof(floating)) + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): - centers_new_chunk = \ - 
malloc(n_clusters * n_features * sizeof(floating)) - - weight_in_clusters_chunk = \ - malloc(n_clusters * sizeof(floating)) - + # thread local buffers + centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) + weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) # initialize local buffers - memset(centers_new_chunk, 0, - n_clusters * n_features * sizeof(floating)) + memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) for chunk_idx in prange(n_chunks): @@ -383,11 +367,10 @@ cpdef void _lloyd_iter_chunked_sparse(X, # race conditions. if update_centers: with gil: - for j in xrange(n_clusters): + for j in range(n_clusters): weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in xrange(n_features): - centers_new[j, k] += \ - centers_new_chunk[j * n_features + k] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] free(weight_in_clusters_chunk) free(centers_new_chunk) @@ -429,15 +412,14 @@ cdef void _update_chunk_sparse(floating *X_data, # XXX Precompute the pairwise distances matrix is not worth for sparse # currently. Should be tested when BLAS (sparse x dense) matrix # multiplication is available. 
- for i in xrange(n_samples): + for i in range(n_samples): min_sq_dist = max_floating best_cluster = 0 - for j in xrange(n_clusters): + for j in range(n_clusters): sq_dist = 0.0 - for k in xrange(X_indptr[i] - s, X_indptr[i + 1] - s): - sq_dist += \ - centers_old[j * n_features + X_indices[k]] * X_data[k] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + sq_dist += centers_old[j * n_features + X_indices[k]] * X_data[k] # Instead of computing the full squared distance with each cluster, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to compute @@ -452,6 +434,5 @@ cdef void _update_chunk_sparse(floating *X_data, if update_centers: weight_in_cluster[best_cluster] += sample_weight[i] - for k in xrange(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[best_cluster * n_features + X_indices[k]] += \ - X_data[k] * sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[best_cluster * n_features + X_indices[k]] += X_data[k] * sample_weight[i] From 0bcc1f1352ce56550c145722fbb6a9031f2a08aa Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 25 Oct 2018 12:37:31 +0200 Subject: [PATCH 012/163] cos --- sklearn/cluster/k_means_.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 311494b9a66e3..2c784ef9fb403 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -454,14 +454,12 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, .format(i, center_shift_tot, tol)) break - if center_shift_tot > 0: - # rerun E-step in case of non-convergence so that predicted labels - # match cluster centers - _elkan_iter_chunked_dense(X, sample_weight, centers, centers, - weight_in_clusters, center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs, - update_centers=False) + # rerun E-step so that predicted labels match cluster 
centers + _elkan_iter_chunked_dense(X, sample_weight, centers, centers, + weight_in_clusters, center_half_distances, + distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs, + update_centers=False) inertia = _inertia_dense(X, sample_weight, centers, labels) @@ -544,6 +542,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, x_squared_norms=x_squared_norms) + if verbose: print("Initialization complete") @@ -577,12 +576,10 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, .format(i, center_shift_tot, tol)) break - if center_shift_tot > 0: - # rerun E-step in case of non-convergence so that predicted labels - # match cluster centers - lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, - centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs, update_centers=False) + # rerun E-step so that predicted labels match cluster centers + lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, + centers_squared_norms, weight_in_clusters, labels, + center_shift, n_jobs, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -625,15 +622,15 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers): center_shift = np.zeros_like(centers_squared_norms) if sp.issparse(X): - labels_centers = _lloyd_iter_chunked_sparse + _labels = _lloyd_iter_chunked_sparse _inertia = _inertia_sparse else: - labels_centers = _lloyd_iter_chunked_dense + _labels = _lloyd_iter_chunked_dense _inertia = _inertia_dense - labels_centers(X, sample_weight, x_squared_norms, centers, - centers, centers_squared_norms, weight_in_clusters, - labels, center_shift, update_centers=False) + _labels(X, sample_weight, x_squared_norms, centers, + centers, centers_squared_norms, weight_in_clusters, + labels, center_shift, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) From 
107290e1ffff8eedbd9e090d8b797ecb50006dd5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 29 Oct 2018 16:47:01 +0100 Subject: [PATCH 013/163] fix algorithm check --- sklearn/cluster/k_means_.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 2c784ef9fb403..8747f091f88f2 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -351,12 +351,14 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', x_squared_norms = row_norms(X, squared=True) best_labels, best_inertia, best_centers = None, None, None - if n_clusters == 1: - # elkan doesn't make sense for a single cluster, full will produce - # the right result. - algorithm = "full" + if algorithm == "auto": - algorithm = "full" if sp.issparse(X) else 'elkan' + algorithm = "full" if sp.issparse(X) else "elkan" + if algorithm == "elkan" and n_clusters == 1: + warnings.warns("algorithm='elkan' doesn't make sense for a single " + "cluster. Using 'full' instead.") + algorithm = "full" + if algorithm == "full": kmeans_single = _kmeans_single_lloyd elif algorithm == "elkan": From aac23505fe4360c6385254e8d221a3e351a50ef9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 29 Oct 2018 17:34:59 +0100 Subject: [PATCH 014/163] typo --- sklearn/cluster/k_means_.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 8747f091f88f2..3a04fef1760ba 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -355,8 +355,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', if algorithm == "auto": algorithm = "full" if sp.issparse(X) else "elkan" if algorithm == "elkan" and n_clusters == 1: - warnings.warns("algorithm='elkan' doesn't make sense for a single " - "cluster. 
Using 'full' instead.") + warnings.warn("algorithm='elkan' doesn't make sense for a single " + "cluster. Using 'full' instead.", RuntimeWarning) algorithm = "full" if algorithm == "full": From 8e432be09dfd4e92bad2f57262feb4ca80138416 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 29 Oct 2018 17:35:25 +0100 Subject: [PATCH 015/163] deprecation precompute in tests --- sklearn/cluster/tests/test_k_means.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 8d213f4310acb..9e3d1271d3c70 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -291,8 +291,7 @@ def test_k_means_fortran_aligned_data(): X = np.asfortranarray([[0, 0], [0, 1], [0, 1]]) centers = np.array([[0, 0], [0, 1]]) labels = np.array([0, 1, 1]) - km = KMeans(n_init=1, init=centers, precompute_distances=False, - random_state=42, n_clusters=2) + km = KMeans(n_init=1, init=centers, random_state=42, n_clusters=2) km.fit(X) assert_array_almost_equal(km.cluster_centers_, centers) assert_array_equal(km.labels_, labels) From 4531fc6808aca3e6bd1ab92ed23f369939dbebee Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 31 Oct 2018 11:34:50 +0100 Subject: [PATCH 016/163] use libc FLT_MAX --- sklearn/cluster/_k_means_lloyd.pyx | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 21f01e83a0e98..459788e07228f 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -4,7 +4,6 @@ import numpy as np cimport numpy as np -cimport cython cimport openmp from cython cimport floating from cython.parallel import prange, parallel @@ -12,6 +11,7 @@ from scipy.linalg.cython_blas cimport sgemm, dgemm from libc.math cimport sqrt from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy +from libc.float 
cimport DBL_MAX, FLT_MAX from ._k_means import (_relocate_empty_clusters_dense, _relocate_empty_clusters_sparse, @@ -21,11 +21,6 @@ from ._k_means import (_relocate_empty_clusters_dense, np.import_array() -cdef: - float MAX_FLT = np.finfo(np.float32).max - double MAX_DBL = np.finfo(np.float64).max - - cdef void xgemm(char *ta, char *tb, int *m, int *n, int *k, floating *alpha, floating *A, int *lda, floating *B, int *ldb, floating *beta, floating *C, int *ldc) nogil: @@ -406,7 +401,7 @@ cdef void _update_chunk_sparse(floating *X_data, cdef: floating sq_dist, min_sq_dist int i, j, k, best_cluster - floating max_floating = MAX_FLT if floating is float else MAX_DBL + floating max_floating = FLT_MAX if floating is float else DBL_MAX int s = X_indptr[0] # XXX Precompute the pairwise distances matrix is not worth for sparse From 52d8aba3499c6fcf0e8a80c9777ff8374deeb4ab Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 31 Oct 2018 11:35:42 +0100 Subject: [PATCH 017/163] setup unlik cblas --- sklearn/cluster/setup.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py index 75b3e355138e4..8690c9aa9afb1 100644 --- a/sklearn/cluster/setup.py +++ b/sklearn/cluster/setup.py @@ -1,7 +1,6 @@ # Author: Alexandre Gramfort # License: BSD 3 clause import os -from os.path import join import numpy @@ -36,16 +35,11 @@ def configuration(parent_package='', top_path=None): libraries=libraries) config.add_extension('_k_means_lloyd', - libraries=cblas_libs, sources=['_k_means_lloyd.pyx'], - include_dirs=[join('..', 'src', 'cblas'), - numpy.get_include(), - blas_info.pop('include_dirs', [])], + include_dirs=[numpy.get_include()], + libraries=libraries, extra_link_args=['-fopenmp'], - extra_compile_args=blas_info.pop( - 'extra_compile_args', []) + ['-fopenmp'], - **blas_info - ) + extra_compile_args=['-fopenmp']) config.add_extension('_k_means_elkan', sources=['_k_means_elkan.pyx'], From 
ff35b295dd3233b5453b6bc07564c15417968ce3 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 11:25:21 +0100 Subject: [PATCH 018/163] remove unecessary blas stuff from setup --- sklearn/cluster/setup.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py index 8690c9aa9afb1..fb83d38dd1e53 100644 --- a/sklearn/cluster/setup.py +++ b/sklearn/cluster/setup.py @@ -4,17 +4,12 @@ import numpy -from sklearn._build_utils import get_blas_info - def configuration(parent_package='', top_path=None): from numpy.distutils.misc_util import Configuration - cblas_libs, blas_info = get_blas_info() - libraries = [] if os.name == 'posix': - cblas_libs.append('m') libraries.append('m') config = Configuration('cluster', parent_package, top_path) From 286aed44d12aed122ff94de61ad91577cabc0f19 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 13:37:20 +0100 Subject: [PATCH 019/163] Add _clibs module to limit number of threads for C-libs --- sklearn/cluster/k_means_.py | 9 + sklearn/utils/_clibs.py | 344 ++++++++++++++++++++++++++++++ sklearn/utils/tests/test_clibs.py | 66 ++++++ 3 files changed, 419 insertions(+) create mode 100644 sklearn/utils/_clibs.py create mode 100644 sklearn/utils/tests/test_clibs.py diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 3a04fef1760ba..2a2bac687ee35 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -28,6 +28,7 @@ from ..utils import check_random_state from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES +from ..utils._clibs import get_thread_limits, limit_threads_clibs from ..utils import effective_n_jobs from ..externals.six import string_types from ..exceptions import ConvergenceWarning @@ -370,6 +371,11 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', n_jobs_ = -1 if n_jobs is None else effective_n_jobs(n_jobs) seeds = 
random_state.randint(np.iinfo(np.int32).max, size=n_init) + # limit number of threads in second level of nested parallelism (i.e. BLAS) + # to avoid oversubsciption + limits = get_thread_limits(reload_clib=True) + limit_threads_clibs(limits=1, subset="blas") + for seed in seeds: # run a k-means once labels, inertia, centers, n_iter_ = kmeans_single( @@ -383,6 +389,9 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', best_inertia = inertia best_n_iter = n_iter_ + # release the limit on threads number and reset to initial value + limit_threads_clibs(limits=limits) + if not sp.issparse(X): if not copy_x: X += X_mean diff --git a/sklearn/utils/_clibs.py b/sklearn/utils/_clibs.py new file mode 100644 index 0000000000000..de5f00a52ebc2 --- /dev/null +++ b/sklearn/utils/_clibs.py @@ -0,0 +1,344 @@ +""" +This module provides utilities to load C-libraries that relies on thread +pools and limit the maximal number of thread that can be used. +""" + +# This code is adapted from code by Thomas Moreau available at +# https://github.com/tomMoral/loky + + +import sys +import os +import threading +import ctypes +from ctypes.util import find_library + + +# Structure to cast the info on dynamically loaded library. See +# https://linux.die.net/man/3/dl_iterate_phdr for more details. +UINT_SYSTEM = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32 +UINT_HALF_SYSTEM = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16 + + +class dl_phdr_info(ctypes.Structure): + _fields_ = [ + ("dlpi_addr", UINT_SYSTEM), # Base address of object + ("dlpi_name", ctypes.c_char_p), # path to the library + ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers + ("dlpi_phnum", UINT_HALF_SYSTEM) # number of element in dlpi_phdr + ] + + +class _CLibsWrapper: + # Wrapper around classic C-libraries for scientific computations to set and + # get the maximum number of threads they are allowed to used for inner + # parallelism. 
+ + # Supported C-libraries for this wrapper, index with their name. The items + # hold the name of the library file and the functions to call. + SUPPORTED_CLIBS = { + "openmp_intel": ( + "libiomp", "omp_set_num_threads", "omp_get_max_threads"), + "openmp_gnu": ( + "libgomp", "omp_set_num_threads", "omp_get_max_threads"), + "openmp_llvm": ( + "libomp", "omp_set_num_threads", "omp_get_max_threads"), + "openmp_win32": ( + "vcomp", "omp_set_num_threads", "omp_get_max_threads"), + "openblas": ( + "libopenblas", "openblas_set_num_threads", + "openblas_get_num_threads"), + "mkl": ( + "libmkl_rt", "MKL_Set_Num_Threads", "MKL_Get_Max_Threads"), + "mkl_win32": ( + "mkl_rt", "MKL_Set_Num_Threads", "MKL_Get_Max_Threads")} + + cls_thread_locals = threading.local() + + def __init__(self): + self._load() + + def _load(self): + for clib, (module_name, _, _) in self.SUPPORTED_CLIBS.items(): + setattr(self, clib, self._load_lib(module_name)) + + def _unload(self): + for clib, (module_name, _, _) in self.SUPPORTED_CLIBS.items(): + delattr(self, clib) + + def limit_threads_clibs(self, limits=1, subset=None): + """Limit maximal number of threads used by supported C-libraries""" + if isinstance(limits, int): + if subset in ("all", None): + clibs = self.SUPPORTED_CLIBS.keys() + elif subset == "blas": + clibs = ("openblas", "mkl", "mkl_win32") + elif subset == "openmp": + clibs = (c for c in self.SUPPORTED_CLIBS if "openmp" in c) + else: + raise ValueError("subset must be either 'all', 'blas' or " + "'openmp'. Got {} instead.".format(subset)) + limits = {clib: limits for clib in clibs} + + if not isinstance(limits, dict): + raise TypeError("limits must either be an int or a dict. 
Got {} " + "instead".format(type(limits))) + + dynamic_threadpool_size = {} + self._load() + for clib, (_, _set, _) in self.SUPPORTED_CLIBS.items(): + if clib in limits: + module = getattr(self, clib, None) + if module is not None: + _set = getattr(module, _set) + _set(limits[clib]) + dynamic_threadpool_size[clib] = True + else: + dynamic_threadpool_size[clib] = False + self._unload() + return dynamic_threadpool_size + + def get_thread_limits(self): + """Return maximal number of threads available for supported C-libraries + """ + limits = {} + self._load() + for clib, (_, _, _get) in self.SUPPORTED_CLIBS.items(): + module = getattr(self, clib, None) + if module is not None: + _get = getattr(module, _get) + limits[clib] = _get() + else: + limits[clib] = None + self._unload() + return limits + + def _load_lib(self, module_name): + """Return a binder on module_name by looping through loaded libraries + """ + if sys.platform == "darwin": + return self._find_with_clibs_dyld(module_name) + elif sys.platform == "win32": + return self._find_with_clibs_enum_process_module_ex(module_name) + return self._find_with_clibs_dl_iterate_phdr(module_name) + + def _find_with_clibs_dl_iterate_phdr(self, module_name): + """Return a binder on module_name by looping through loaded libraries + + This function is expected to work on POSIX system only. + This code is adapted from code by Intel developper @anton-malakhov + available at https://github.com/IntelPython/smp + + Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause + license + """ + self.cls_thread_locals._module_path = None + + libc = self._get_libc() + if not hasattr(libc, "dl_iterate_phdr"): + return + + # Callback function for `dl_iterate_phdr` which is called for every + # module loaded in the current process until it returns 1. 
+ def match_module_callback(info, size, module_name): + + # recast the name of the module as a string + module_name = ctypes.string_at(module_name).decode('utf-8') + + # Get the name of the current library + module_path = info.contents.dlpi_name + + # If the current library is the one we are looking for, store the + # path and return 1 to stop the loop in `dl_iterate_phdr`. + if module_path: + module_path = module_path.decode("utf-8") + if os.path.basename(module_path).startswith(module_name): + self.cls_thread_locals._module_path = module_path + return 1 + return 0 + + c_func_signature = ctypes.CFUNCTYPE( + ctypes.c_int, # Return type + ctypes.POINTER(dl_phdr_info), ctypes.c_size_t, ctypes.c_char_p) + c_match_module_callback = c_func_signature(match_module_callback) + + data = ctypes.c_char_p(module_name.encode('utf-8')) + res = libc.dl_iterate_phdr(c_match_module_callback, data) + if res == 1: + return ctypes.CDLL(self.cls_thread_locals._module_path) + + def _find_with_clibs_dyld(self, module_name): + """Return a binder on module_name by looping through loaded libraries + + This function is expected to work on OSX system only + """ + libc = self._get_libc() + if not hasattr(libc, "_dyld_image_count"): + return + + found_module_path = None + + n_dyld = libc._dyld_image_count() + libc._dyld_get_image_name.restype = ctypes.c_char_p + + for i in range(n_dyld): + module_path = ctypes.string_at(libc._dyld_get_image_name(i)) + module_path = module_path.decode("utf-8") + if os.path.basename(module_path).startswith(module_name): + found_module_path = module_path + + if found_module_path: + return ctypes.CDLL(found_module_path) + + def _find_with_clibs_enum_process_module_ex(self, module_name): + """Return a binder on module_name by looping through loaded libraries + + This function is expected to work on windows system only. 
+ This code is adapted from code by Philipp Hagemeister @phihag available + at https://stackoverflow.com/questions/17474574 + """ + from ctypes.wintypes import DWORD, HMODULE, MAX_PATH + + PROCESS_QUERY_INFORMATION = 0x0400 + PROCESS_VM_READ = 0x0010 + + LIST_MODULES_ALL = 0x03 + + Psapi = self._get_windll('Psapi') + Kernel32 = self._get_windll('kernel32') + + hProcess = Kernel32.OpenProcess( + PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, + False, os.getpid()) + if not hProcess: + raise OSError('Could not open PID %s' % os.getpid()) + + found_module_path = None + try: + buf_count = 256 + needed = DWORD() + # Grow the buffer until it becomes large enough to hold all the + # module headers + while True: + buf = (HMODULE * buf_count)() + buf_size = ctypes.sizeof(buf) + if not Psapi.EnumProcessModulesEx( + hProcess, ctypes.byref(buf), buf_size, + ctypes.byref(needed), LIST_MODULES_ALL): + raise OSError('EnumProcessModulesEx failed') + if buf_size >= needed.value: + break + buf_count = needed.value // (buf_size // buf_count) + + count = needed.value // (buf_size // buf_count) + hModules = map(HMODULE, buf[:count]) + + # Loop through all the module headers and get the module file name + buf = ctypes.create_unicode_buffer(MAX_PATH) + nSize = DWORD() + for hModule in hModules: + if not Psapi.GetModuleFileNameExW( + hProcess, hModule, ctypes.byref(buf), + ctypes.byref(nSize)): + raise OSError('GetModuleFileNameEx failed') + module_path = buf.value + module_basename = os.path.basename(module_path).lower() + if module_basename.startswith(module_name): + found_module_path = module_path + finally: + Kernel32.CloseHandle(hProcess) + + if found_module_path: + return ctypes.CDLL(found_module_path) + + def _get_libc(self): + if not hasattr(self, "libc"): + libc_name = find_library("c") + if libc_name is None: + self.libc = None + self.libc = ctypes.CDLL(libc_name) + + return self.libc + + def _get_windll(self, dll_name): + if not hasattr(self, dll_name): + setattr(self, dll_name, 
ctypes.WinDLL("{}.dll".format(dll_name))) + + return getattr(self, dll_name) + + +_clibs_wrapper = None + + +def _get_wrapper(reload_clib=False): + """Helper function to only create one wrapper per thread.""" + global _clibs_wrapper + if _clibs_wrapper is None: + _clibs_wrapper = _CLibsWrapper() + if reload_clib: + _clibs_wrapper._load() + return _clibs_wrapper + + +def limit_threads_clibs(limits=1, subset=None, reload_clib=False): + """Limit the number of threads available for threadpools in supported C-lib + + Set the maximal number of thread that can be used in thread pools used in + the supported C-libraries. This function works for libraries that are + already loaded in the interpreter and can be changed dynamically. + + Parameters + ---------- + limits : int or dict, (default=1) + Maximum number of thread that can be used in thread pools + + If int, sets the maximum number of thread to `limits` for each C-lib + selected by `subset`. + + If dict(supported_libraries: max_threads), sets a custom maximum number + of thread for each C-lib. + + subset : string or None, optional (default="all") + Subset of C-libs to limit. Used only if `limits` is an int + + "all" : limit all supported C-libs. + + "blas" : limit only BLAS supported C-libs. + + "openmp" : limit only OpenMP supported C-libs. It can affect the number + of threads used by the BLAS C-libs if they rely on OpenMP. + + reload_clib : bool, (default=False) + If `reload_clib` is `True`, first loop through the loaded libraries to + ensure that this function is called on all available libraries. + + Returns + ------- + dynamic_threadpool_size : dict + contains pairs `('clib': boolean)` which are True if `clib` have been + found and can be used to scale the maximal number of threads + dynamically. 
+ """ + wrapper = _get_wrapper(reload_clib) + return wrapper.limit_threads_clibs(limits, subset) + + +def get_thread_limits(reload_clib=True): + """Return maximal thread number for threadpools in supported C-lib + + Parameters + ---------- + reload_clib : bool, (default=True) + If `reload_clib` is `True`, first loop through the loaded libraries to + ensure that this function is called on all available libraries. + + Returns + ------- + thread_limits : dict + Contains the maximal number of threads that can be used in supported + libraries or None when the library is not available. The key of the + dictionary are "openmp_gnu", "openmp_intel", "openmp_win32", + "openmp_llvm", "openblas", "mkl" and "mkl_win32". + """ + wrapper = _get_wrapper(reload_clib) + return wrapper.get_thread_limits() diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py new file mode 100644 index 0000000000000..c382ef0d2a9ec --- /dev/null +++ b/sklearn/utils/tests/test_clibs.py @@ -0,0 +1,66 @@ +import os + +import pytest + +from sklearn.utils._clibs import get_thread_limits, limit_threads_clibs +from sklearn.utils._clibs import _CLibsWrapper + + +@pytest.mark.parametrize("clib", _CLibsWrapper.SUPPORTED_CLIBS) +def test_limit_threads_clib_dict(clib): + old_limits = get_thread_limits() + + if old_limits[clib] is not None: + dynamic_scaling = limit_threads_clibs(limits={clib: 1}) + assert get_thread_limits()[clib] == 1 + assert dynamic_scaling[clib] + + limit_threads_clibs(limits={clib: 3}) + new_limits = get_thread_limits() + assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) + + limit_threads_clibs(limits=old_limits) + new_limits = get_thread_limits() + assert new_limits[clib] == old_limits[clib] + + +@pytest.mark.parametrize("subset", ("all", "blas", "openmp")) +def test_limit_threads_clib_subset(subset): + if subset == "all": + clibs = _CLibsWrapper.SUPPORTED_CLIBS.keys() + elif subset == "blas": + clibs = ("openblas", "mkl", "mkl_win32") + elif 
subset == "openmp": + clibs = (c for c in _CLibsWrapper.SUPPORTED_CLIBS if "openmp" in c) + + old_limits = get_thread_limits() + + dynamic_scaling = limit_threads_clibs(limits=1, subset=subset) + new_limits = get_thread_limits() + for clib in clibs: + if old_limits[clib] is not None: + assert new_limits[clib] == 1 + assert dynamic_scaling[clib] + + limit_threads_clibs(limits=3, subset=subset) + new_limits = get_thread_limits() + for clib in clibs: + if old_limits[clib] is not None: + assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) + + limit_threads_clibs(limits=old_limits) + new_limits = get_thread_limits() + for clib in clibs: + if old_limits[clib] is not None: + assert new_limits[clib] == old_limits[clib] + + +def test_limit_threads_clib_bad_input(): + with pytest.raises(ValueError, + match="subset must be either 'all', 'blas' " + "or 'openmp'"): + limit_threads_clibs(limits=1, subset="wrong") + + with pytest.raises(TypeError, + match="limits must either be an int or a dict"): + limit_threads_clibs(limits=(1, 2, 3)) From e720abe71ada592b945419ff4898528f5ddb8d1d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 13:58:21 +0100 Subject: [PATCH 020/163] fix merge conflict --- .circleci/config.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index f242e4a516edb..550c1219f2fba 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -103,6 +103,7 @@ jobs: command: | if [[ "${CIRCLE_BRANCH}" =~ ^master$|^[0-9]+\.[0-9]+\.X$ ]]; then bash build_tools/circle/push_doc.sh doc/_build/html/stable +<<<<<<< 16b3f13d9c9e44d9fb329800ddc37e42b000ffd6 <<<<<<< 6c0faf614b525bad520269e28ec684b44c00c22a fi @@ -134,3 +135,6 @@ workflows: ======= fi >>>>>>> drop python 2 CI +======= + fi +>>>>>>> fix merge conflict From 5d82b8d16553acea23c3023aca8c9860523298a4 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 14:55:16 +0100 Subject: [PATCH 021/163] fix 
import deprecated --- sklearn/cluster/k_means_.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 2a2bac687ee35..5a10983004ce9 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -29,7 +29,7 @@ from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES from ..utils._clibs import get_thread_limits, limit_threads_clibs -from ..utils import effective_n_jobs +from ..utils._joblib import effective_n_jobs from ..externals.six import string_types from ..exceptions import ConvergenceWarning from ._k_means import (_inertia_dense, @@ -373,7 +373,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', # limit number of threads in second level of nested parallelism (i.e. BLAS) # to avoid oversubsciption - limits = get_thread_limits(reload_clib=True) + limits = get_thread_limits() limit_threads_clibs(limits=1, subset="blas") for seed in seeds: From 2368f706b89252ab842cb2f01f58b8f62980164c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 16:15:21 +0100 Subject: [PATCH 022/163] try to fix clib tests ?? 
--- sklearn/utils/_clibs.py | 2 ++ sklearn/utils/tests/test_clibs.py | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_clibs.py b/sklearn/utils/_clibs.py index de5f00a52ebc2..16f8069ad4572 100644 --- a/sklearn/utils/_clibs.py +++ b/sklearn/utils/_clibs.py @@ -95,6 +95,8 @@ def limit_threads_clibs(self, limits=1, subset=None): dynamic_threadpool_size[clib] = True else: dynamic_threadpool_size[clib] = False + else: + dynamic_threadpool_size[clib] = False self._unload() return dynamic_threadpool_size diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index c382ef0d2a9ec..192a5f95417cf 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -12,15 +12,15 @@ def test_limit_threads_clib_dict(clib): if old_limits[clib] is not None: dynamic_scaling = limit_threads_clibs(limits={clib: 1}) - assert get_thread_limits()[clib] == 1 + assert get_thread_limits(reload_clib=False)[clib] == 1 assert dynamic_scaling[clib] limit_threads_clibs(limits={clib: 3}) - new_limits = get_thread_limits() + new_limits = get_thread_limits(reload_clib=False) assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) limit_threads_clibs(limits=old_limits) - new_limits = get_thread_limits() + new_limits = get_thread_limits(reload_clib=False) assert new_limits[clib] == old_limits[clib] @@ -36,20 +36,20 @@ def test_limit_threads_clib_subset(subset): old_limits = get_thread_limits() dynamic_scaling = limit_threads_clibs(limits=1, subset=subset) - new_limits = get_thread_limits() + new_limits = get_thread_limits(reload_clib=False) for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == 1 assert dynamic_scaling[clib] limit_threads_clibs(limits=3, subset=subset) - new_limits = get_thread_limits() + new_limits = get_thread_limits(reload_clib=False) for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) 
limit_threads_clibs(limits=old_limits) - new_limits = get_thread_limits() + new_limits = get_thread_limits(reload_clib=False) for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == old_limits[clib] From 4d960a3c2877aa62e8474be90de46b531685f1f9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 29 Nov 2018 16:42:44 +0100 Subject: [PATCH 023/163] doesn't work... revert --- sklearn/utils/tests/test_clibs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index 192a5f95417cf..c382ef0d2a9ec 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -12,15 +12,15 @@ def test_limit_threads_clib_dict(clib): if old_limits[clib] is not None: dynamic_scaling = limit_threads_clibs(limits={clib: 1}) - assert get_thread_limits(reload_clib=False)[clib] == 1 + assert get_thread_limits()[clib] == 1 assert dynamic_scaling[clib] limit_threads_clibs(limits={clib: 3}) - new_limits = get_thread_limits(reload_clib=False) + new_limits = get_thread_limits() assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) limit_threads_clibs(limits=old_limits) - new_limits = get_thread_limits(reload_clib=False) + new_limits = get_thread_limits() assert new_limits[clib] == old_limits[clib] @@ -36,20 +36,20 @@ def test_limit_threads_clib_subset(subset): old_limits = get_thread_limits() dynamic_scaling = limit_threads_clibs(limits=1, subset=subset) - new_limits = get_thread_limits(reload_clib=False) + new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == 1 assert dynamic_scaling[clib] limit_threads_clibs(limits=3, subset=subset) - new_limits = get_thread_limits(reload_clib=False) + new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) 
limit_threads_clibs(limits=old_limits) - new_limits = get_thread_limits(reload_clib=False) + new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == old_limits[clib] From afc306b67ae7cb7e300f1b9ec02529dee21c44a2 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 30 Nov 2018 16:41:37 +0100 Subject: [PATCH 024/163] add header for _k_means to export cdef funcs --- sklearn/cluster/_k_means.pxd | 30 ++++++++++++++++++++++++++++ sklearn/cluster/_k_means.pyx | 32 +++++++++++++++--------------- sklearn/cluster/_k_means_elkan.pyx | 2 +- sklearn/cluster/_k_means_lloyd.pyx | 6 +++--- 4 files changed, 50 insertions(+), 20 deletions(-) create mode 100644 sklearn/cluster/_k_means.pxd diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd new file mode 100644 index 0000000000000..7edb31597b532 --- /dev/null +++ b/sklearn/cluster/_k_means.pxd @@ -0,0 +1,30 @@ +from cython cimport floating +cimport numpy as np + + +cdef void _relocate_empty_clusters_dense( + np.ndarray[floating, ndim=2, mode='c'], + floating[::1], + floating[:, ::1], + floating[::1], + int[::1] +) + + +cdef void _relocate_empty_clusters_sparse( + floating[::1], + int[::1], + int[::1], + floating[::1], + floating[:, ::1], + floating[::1], + int[::1] +) + + +cdef void _mean_and_center_shift( + floating[:, ::1], + floating[:, ::1], + floating[::1], + floating[::1] +) \ No newline at end of file diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 7e619532daceb..04502bf36c88c 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -88,11 +88,11 @@ cpdef floating _inertia_sparse(X, return inertia -cpdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, - floating[::1] sample_weight, - floating[:, ::1] centers, - floating[::1] weight_in_clusters, - int[::1] labels): +cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, + 
floating[::1] sample_weight, + floating[:, ::1] centers, + floating[::1] weight_in_clusters, + int[::1] labels): """Relocate centers which have no sample assigned to them""" cdef: int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) @@ -129,13 +129,13 @@ cpdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] weight_in_clusters[old_cluster_id] -= weight -cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, - int[::1] X_indices, - int[::1] X_indptr, - floating[::1] sample_weight, - floating[:, ::1] centers, - floating[::1] weight_in_clusters, - int[::1] labels): +cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[:, ::1] centers, + floating[::1] weight_in_clusters, + int[::1] labels): """Relocate centers which have no sample assigned to them""" cdef: int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) @@ -181,10 +181,10 @@ cpdef void _relocate_empty_clusters_sparse(floating[::1] X_data, weight_in_clusters[old_cluster_id] -= weight -cpdef void _mean_and_center_shift(floating[:, ::1] centers_old, - floating[:, ::1] centers_new, - floating[::1] weight_in_clusters, - floating[::1] center_shift): +cdef void _mean_and_center_shift(floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] weight_in_clusters, + floating[::1] center_shift): """Average new centers wrt weights and compute center shift""" cdef: int n_clusters = centers_old.shape[0] diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 5ce93567f8794..a53089ae77e49 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -15,7 +15,7 @@ from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from ..metrics import euclidean_distances -from ._k_means import _relocate_empty_clusters_dense, 
_mean_and_center_shift +from ._k_means cimport _relocate_empty_clusters_dense, _mean_and_center_shift np.import_array() diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 459788e07228f..3fa7449603ea4 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -13,9 +13,9 @@ from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX -from ._k_means import (_relocate_empty_clusters_dense, - _relocate_empty_clusters_sparse, - _mean_and_center_shift) +from ._k_means cimport (_relocate_empty_clusters_dense, + _relocate_empty_clusters_sparse, + _mean_and_center_shift) np.import_array() From 9ffb9caefce51787e4d5df429f8e398a95d0483e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 5 Dec 2018 14:04:05 +0100 Subject: [PATCH 025/163] calloc instead of malloc --- sklearn/cluster/_k_means_elkan.pyx | 11 ++++------- sklearn/cluster/_k_means_lloyd.pyx | 20 +++++++------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index a53089ae77e49..20d38fb1dc4dc 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -11,7 +11,7 @@ cimport openmp from cython cimport floating from cython.parallel import prange, parallel from libc.math cimport sqrt -from libc.stdlib cimport malloc, free +from libc.stdlib cimport calloc, free from libc.string cimport memset, memcpy from ..metrics import euclidean_distances @@ -211,12 +211,9 @@ shape (n_clusters, n_clusters) with nogil, parallel(num_threads=num_threads): # thread local buffers - centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) - weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) - # initialize local buffers - memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) - memset(weight_in_clusters_chunk, 0, n_clusters * 
sizeof(floating)) - + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + for chunk_idx in prange(n_chunks): if n_samples_r > 0 and chunk_idx == n_chunks - 1: n_samples_chunk_eff = n_samples_r diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 3fa7449603ea4..a4bb0f61b60a9 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -9,7 +9,7 @@ from cython cimport floating from cython.parallel import prange, parallel from scipy.linalg.cython_blas cimport sgemm, dgemm from libc.math cimport sqrt -from libc.stdlib cimport malloc, free +from libc.stdlib cimport calloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX @@ -126,13 +126,10 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): # thread local buffers - centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) - weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) - pairwise_distances_chunk = malloc(n_samples_chunk * n_clusters * sizeof(floating)) - # initialize local buffers - memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) - memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) - + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + pairwise_distances_chunk = calloc(n_samples_chunk * n_clusters, sizeof(floating)) + for chunk_idx in prange(n_chunks): if n_samples_r > 0 and chunk_idx == n_chunks - 1: n_samples_chunk_eff = n_samples_r @@ -330,11 +327,8 @@ cpdef void _lloyd_iter_chunked_sparse(X, num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): # thread local 
buffers - centers_new_chunk = malloc(n_clusters * n_features * sizeof(floating)) - weight_in_clusters_chunk = malloc(n_clusters * sizeof(floating)) - # initialize local buffers - memset(centers_new_chunk, 0, n_clusters * n_features * sizeof(floating)) - memset(weight_in_clusters_chunk, 0, n_clusters * sizeof(floating)) + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) for chunk_idx in prange(n_chunks): if n_samples_r > 0 and chunk_idx == n_chunks - 1: From aced525ecbccfa3edad3c5fd6dc7b2671d08dcef Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 13 Dec 2018 17:20:12 +0100 Subject: [PATCH 026/163] tst build --- .travis.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.travis.yml b/.travis.yml index 2926f2df560ba..31d903a36ac93 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,7 +48,17 @@ matrix: NUMPY_VERSION="1.11.0" SCIPY_VERSION="0.17.0" CYTHON_VERSION="*" PILLOW_VERSION="4.0.0" COVERAGE=true if: type != cron +<<<<<<< 91f39bd55a48e487720306b974f29c7de46d2209 # Linux environment to test the latest available dependencies and MKL. +======= + # Python 3.5 build tst + - env: DISTRIB="conda" PYTHON_VERSION="3.5" INSTALL_MKL="false" + NUMPY_VERSION="1.14" SCIPY_VERSION="1.0" CYTHON_VERSION="0.25.2" + PILLOW_VERSION="4.0.0" COVERAGE=true + SKLEARN_SITE_JOBLIB=1 JOBLIB_VERSION="0.11" + if: type != cron + # This environment tests the latest available dependencies. +>>>>>>> tst build # It runs tests requiring pandas and PyAMG. # It also runs with the site joblib instead of the vendored copy of joblib. 
- env: DISTRIB="conda" PYTHON_VERSION="*" INSTALL_MKL="true" From 497f8990fc3ab8ce32f6fcaa885877021ed8b134 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 13:58:33 +0100 Subject: [PATCH 027/163] add get_openblas_version to clibs and skip tests with old openblas --- sklearn/utils/_clibs.py | 30 ++++++++++++++++++++++++++++++ sklearn/utils/tests/test_clibs.py | 27 ++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_clibs.py b/sklearn/utils/_clibs.py index 16f8069ad4572..73be2f6aeb5df 100644 --- a/sklearn/utils/_clibs.py +++ b/sklearn/utils/_clibs.py @@ -115,6 +115,17 @@ def get_thread_limits(self): self._unload() return limits + def get_openblas_version(self): + module = getattr(self, "openblas", None) + if module is not None: + get_config = getattr(module, "openblas_get_config") + get_config.restype = ctypes.c_char_p + config = get_config().split() + if config[0] == b"OpenBLAS": + return config[1].decode('utf-8') + return + return + def _load_lib(self, module_name): """Return a binder on module_name by looping through loaded libraries """ @@ -344,3 +355,22 @@ def get_thread_limits(reload_clib=True): """ wrapper = _get_wrapper(reload_clib) return wrapper.get_thread_limits() + + +def get_openblas_version(reload_clib=True): + """Return the OpenBLAS version + + Parameters + ---------- + reload_clib : bool, (default=True) + If `reload_clib` is `True`, first loop through the loaded libraries to + ensure that this function is called on all available libraries. + + Returns + ------- + version : string or None + None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS + did not expose it's verion before that. 
+ """ + wrapper = _get_wrapper(reload_clib) + return wrapper.get_openblas_version() diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index c382ef0d2a9ec..17dd78c16ea92 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -2,12 +2,25 @@ import pytest +from sklearn.utils.testing import SkipTest from sklearn.utils._clibs import get_thread_limits, limit_threads_clibs +from sklearn.utils._clibs import get_openblas_version from sklearn.utils._clibs import _CLibsWrapper +SKIP_OPENBLAS = get_openblas_version() is None + + @pytest.mark.parametrize("clib", _CLibsWrapper.SUPPORTED_CLIBS) def test_limit_threads_clib_dict(clib): + # Check that the number of threads used by the multithreaded C-libs can be + # modified dynamically. + + if clib is "openblas" and SKIP_OPENBLAS: + raise SkipTest("Possible bug in getting maximum number of threads with" + " OpenBLAS < 0.2.16 and OpenBLAS does not expose it's " + "version before 0.3.4.") + old_limits = get_thread_limits() if old_limits[clib] is not None: @@ -26,12 +39,18 @@ def test_limit_threads_clib_dict(clib): @pytest.mark.parametrize("subset", ("all", "blas", "openmp")) def test_limit_threads_clib_subset(subset): + # Check that the number of threads used by the multithreaded C-libs can be + # modified dynamically. 
+ if subset == "all": - clibs = _CLibsWrapper.SUPPORTED_CLIBS.keys() + clibs = list(_CLibsWrapper.SUPPORTED_CLIBS.keys()) elif subset == "blas": - clibs = ("openblas", "mkl", "mkl_win32") + clibs = ["openblas", "mkl", "mkl_win32"] elif subset == "openmp": - clibs = (c for c in _CLibsWrapper.SUPPORTED_CLIBS if "openmp" in c) + clibs = list(c for c in _CLibsWrapper.SUPPORTED_CLIBS if "openmp" in c) + + if SKIP_OPENBLAS and "openblas" in clibs: + clibs.remove("openblas") old_limits = get_thread_limits() @@ -56,6 +75,8 @@ def test_limit_threads_clib_subset(subset): def test_limit_threads_clib_bad_input(): + # Check that appropriate errors are raised for invalid arguments + with pytest.raises(ValueError, match="subset must be either 'all', 'blas' " "or 'openmp'"): From 6fbe8b19feddb62dc100837f85e0ef65c9723d75 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 13:59:16 +0100 Subject: [PATCH 028/163] cython directive language_level --- sklearn/cluster/_k_means.pxd | 3 +++ sklearn/cluster/_k_means_elkan.pyx | 1 + sklearn/cluster/_k_means_lloyd.pyx | 1 + 3 files changed, 5 insertions(+) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index 7edb31597b532..d2255b4d49363 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -1,3 +1,6 @@ +# cython: language_level=3 + + from cython cimport floating cimport numpy as np diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 20d38fb1dc4dc..efdb104ade822 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -1,4 +1,5 @@ # cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: language_level=3 # # Author: Andreas Mueller # diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index a4bb0f61b60a9..f75490756c61b 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -1,4 +1,5 @@ # cython: 
profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: language_level=3 # # Licence: BSD 3 clause From bcb727e46a08880edf767f5b2fdeb10af69b8926 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 14:07:41 +0100 Subject: [PATCH 029/163] fix merge conflicts --- .travis.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.travis.yml b/.travis.yml index 31d903a36ac93..bff6c5acb6672 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,6 +48,7 @@ matrix: NUMPY_VERSION="1.11.0" SCIPY_VERSION="0.17.0" CYTHON_VERSION="*" PILLOW_VERSION="4.0.0" COVERAGE=true if: type != cron +<<<<<<< 8337ef1d2fffb8cdd6ab8a42e404ecf901c6f912 <<<<<<< 91f39bd55a48e487720306b974f29c7de46d2209 # Linux environment to test the latest available dependencies and MKL. ======= @@ -57,6 +58,8 @@ matrix: PILLOW_VERSION="4.0.0" COVERAGE=true SKLEARN_SITE_JOBLIB=1 JOBLIB_VERSION="0.11" if: type != cron +======= +>>>>>>> fix merge conflicts # This environment tests the latest available dependencies. >>>>>>> tst build # It runs tests requiring pandas and PyAMG. 
From e4c159cf325b56913214e13513f1e091e08b0178 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 14:08:17 +0100 Subject: [PATCH 030/163] fix merge conflicts --- .circleci/config.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index 550c1219f2fba..cb4224f937e2b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -137,4 +137,33 @@ workflows: >>>>>>> drop python 2 CI ======= fi +<<<<<<< 5c998549df27cfb3dd21d7ce54b04229c24bd5cc >>>>>>> fix merge conflict +======= + +workflows: + version: 2 + build-doc-and-deploy: + jobs: + - doc + - doc-min-dependencies + - lint + - pypy3: + filters: + branches: + only: + - 0.20.X + - deploy: + requires: + - python3 + pypy: + triggers: + - schedule: + cron: "0 0 * * *" + filters: + branches: + only: + - master + jobs: + - pypy3 +>>>>>>> fix merge conflicts From 684ea4e4dfe34d8d10a4f49774beba1299b85424 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 14:52:51 +0100 Subject: [PATCH 031/163] thread limit context manager --- sklearn/cluster/k_means_.py | 33 +++++++++++------------- sklearn/utils/_clibs.py | 41 +++++++++++++++++++++++++++--- sklearn/utils/tests/test_clibs.py | 42 ++++++++++++++++++++----------- 3 files changed, 80 insertions(+), 36 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 5a10983004ce9..75756b39fd699 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -28,7 +28,7 @@ from ..utils import check_random_state from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES -from ..utils._clibs import get_thread_limits, limit_threads_clibs +from ..utils._clibs import thread_limits_context from ..utils._joblib import effective_n_jobs from ..externals.six import string_types from ..exceptions import ConvergenceWarning @@ -373,24 +373,19 @@ def k_means(X, n_clusters, 
sample_weight=None, init='k-means++', # limit number of threads in second level of nested parallelism (i.e. BLAS) # to avoid oversubsciption - limits = get_thread_limits() - limit_threads_clibs(limits=1, subset="blas") - - for seed in seeds: - # run a k-means once - labels, inertia, centers, n_iter_ = kmeans_single( - X, sample_weight, n_clusters, max_iter=max_iter, init=init, - verbose=verbose, tol=tol, x_squared_norms=x_squared_norms, - random_state=seed, n_jobs=n_jobs_) - # determine if these results are the best so far - if best_inertia is None or inertia < best_inertia: - best_labels = labels.copy() - best_centers = centers.copy() - best_inertia = inertia - best_n_iter = n_iter_ - - # release the limit on threads number and reset to initial value - limit_threads_clibs(limits=limits) + with thread_limits_context(limits=1, subset="blas"): + for seed in seeds: + # run a k-means once + labels, inertia, centers, n_iter_ = kmeans_single( + X, sample_weight, n_clusters, max_iter=max_iter, init=init, + verbose=verbose, tol=tol, x_squared_norms=x_squared_norms, + random_state=seed, n_jobs=n_jobs_) + # determine if these results are the best so far + if best_inertia is None or inertia < best_inertia: + best_labels = labels.copy() + best_centers = centers.copy() + best_inertia = inertia + best_n_iter = n_iter_ if not sp.issparse(X): if not copy_x: diff --git a/sklearn/utils/_clibs.py b/sklearn/utils/_clibs.py index 73be2f6aeb5df..0fab3924cf737 100644 --- a/sklearn/utils/_clibs.py +++ b/sklearn/utils/_clibs.py @@ -12,6 +12,7 @@ import threading import ctypes from ctypes.util import find_library +from contextlib import contextmanager as contextmanager # Structure to cast the info on dynamically loaded library. 
See @@ -66,7 +67,7 @@ def _unload(self): for clib, (module_name, _, _) in self.SUPPORTED_CLIBS.items(): delattr(self, clib) - def limit_threads_clibs(self, limits=1, subset=None): + def set_thread_limits(self, limits=1, subset=None): """Limit maximal number of threads used by supported C-libraries""" if isinstance(limits, int): if subset in ("all", None): @@ -293,7 +294,7 @@ def _get_wrapper(reload_clib=False): return _clibs_wrapper -def limit_threads_clibs(limits=1, subset=None, reload_clib=False): +def set_thread_limits(limits=1, subset=None, reload_clib=False): """Limit the number of threads available for threadpools in supported C-lib Set the maximal number of thread that can be used in thread pools used in @@ -333,7 +334,7 @@ def limit_threads_clibs(limits=1, subset=None, reload_clib=False): dynamically. """ wrapper = _get_wrapper(reload_clib) - return wrapper.limit_threads_clibs(limits, subset) + return wrapper.set_thread_limits(limits, subset) def get_thread_limits(reload_clib=True): @@ -357,6 +358,40 @@ def get_thread_limits(reload_clib=True): return wrapper.get_thread_limits() +@contextmanager +def thread_limits_context(limits=1, subset=None): + """Context manager for C-libs thread limits + + Parameters + ---------- + limits : int or dict, (default=1) + Maximum number of thread that can be used in thread pools + + If int, sets the maximum number of thread to `limits` for each C-lib + selected by `subset`. + + If dict(supported_libraries: max_threads), sets a custom maximum number + of thread for each C-lib. + + subset : string or None, optional (default="all") + Subset of C-libs to limit. Used only if `limits` is an int + + "all" : limit all supported C-libs. + + "blas" : limit only BLAS supported C-libs. + + "openmp" : limit only OpenMP supported C-libs. It can affect the number + of threads used by the BLAS C-libs if they rely on OpenMP. 
+ """ + old_limits = get_thread_limits() + set_thread_limits(limits=limits, subset=subset) + + try: + yield + finally: + set_thread_limits(limits=old_limits) + + def get_openblas_version(reload_clib=True): """Return the OpenBLAS version diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index 17dd78c16ea92..ba75a6c51f83d 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -3,16 +3,16 @@ import pytest from sklearn.utils.testing import SkipTest -from sklearn.utils._clibs import get_thread_limits, limit_threads_clibs -from sklearn.utils._clibs import get_openblas_version -from sklearn.utils._clibs import _CLibsWrapper +from sklearn.utils._clibs import (get_thread_limits, set_thread_limits, + get_openblas_version, thread_limits_context, + _CLibsWrapper) SKIP_OPENBLAS = get_openblas_version() is None @pytest.mark.parametrize("clib", _CLibsWrapper.SUPPORTED_CLIBS) -def test_limit_threads_clib_dict(clib): +def test_set_thread_limits_dict(clib): # Check that the number of threads used by the multithreaded C-libs can be # modified dynamically. 
@@ -24,21 +24,21 @@ def test_limit_threads_clib_dict(clib): old_limits = get_thread_limits() if old_limits[clib] is not None: - dynamic_scaling = limit_threads_clibs(limits={clib: 1}) + dynamic_scaling = set_thread_limits(limits={clib: 1}) assert get_thread_limits()[clib] == 1 assert dynamic_scaling[clib] - limit_threads_clibs(limits={clib: 3}) + set_thread_limits(limits={clib: 3}) new_limits = get_thread_limits() assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) - limit_threads_clibs(limits=old_limits) + set_thread_limits(limits=old_limits) new_limits = get_thread_limits() assert new_limits[clib] == old_limits[clib] @pytest.mark.parametrize("subset", ("all", "blas", "openmp")) -def test_limit_threads_clib_subset(subset): +def test_set_thread_limits_subset(subset): # Check that the number of threads used by the multithreaded C-libs can be # modified dynamically. @@ -54,34 +54,48 @@ def test_limit_threads_clib_subset(subset): old_limits = get_thread_limits() - dynamic_scaling = limit_threads_clibs(limits=1, subset=subset) + dynamic_scaling = set_thread_limits(limits=1, subset=subset) new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == 1 assert dynamic_scaling[clib] - limit_threads_clibs(limits=3, subset=subset) + set_thread_limits(limits=3, subset=subset) new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) - limit_threads_clibs(limits=old_limits) + set_thread_limits(limits=old_limits) new_limits = get_thread_limits() for clib in clibs: if old_limits[clib] is not None: assert new_limits[clib] == old_limits[clib] -def test_limit_threads_clib_bad_input(): +def test_set_thread_limits_bad_input(): # Check that appropriate errors are raised for invalid arguments with pytest.raises(ValueError, match="subset must be either 'all', 'blas' " "or 'openmp'"): - limit_threads_clibs(limits=1, 
subset="wrong") + set_thread_limits(limits=1, subset="wrong") with pytest.raises(TypeError, match="limits must either be an int or a dict"): - limit_threads_clibs(limits=(1, 2, 3)) + set_thread_limits(limits=(1, 2, 3)) + + +def test_thread_limit_context(): + old_limits = get_thread_limits() + + with thread_limits_context(limits=1): + limits = get_thread_limits() + for clib in limits: + if old_limits[clib] is None: + assert limits[clib] is None + else: + assert limits[clib] == 1 + + assert get_thread_limits() == old_limits From 67900ad7f8d248e280277f3ea6a41b56be7d9ee1 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 14 Dec 2018 15:37:32 +0100 Subject: [PATCH 032/163] skip openblas --- sklearn/utils/tests/test_clibs.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index ba75a6c51f83d..c1e4376d00289 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -92,6 +92,9 @@ def test_thread_limit_context(): with thread_limits_context(limits=1): limits = get_thread_limits() + if SKIP_OPENBLAS: + del limits["openblas"] + for clib in limits: if old_limits[clib] is None: assert limits[clib] is None From c1d262f70dcad8350095a307cbf1c029fb3841a4 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 17 Dec 2018 09:34:12 +0100 Subject: [PATCH 033/163] new line end of file --- sklearn/cluster/_k_means.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index d2255b4d49363..13b65491b8bae 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -30,4 +30,4 @@ cdef void _mean_and_center_shift( floating[:, ::1], floating[::1], floating[::1] -) \ No newline at end of file +) From ed308b755e1e66914af9a4c80ff0178b574e088b Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 21 Dec 2018 12:33:04 +0100 Subject: [PATCH 034/163] merge master 
CI --- .circleci/config.yml | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cb4224f937e2b..89cc103ec6301 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -103,8 +103,6 @@ jobs: command: | if [[ "${CIRCLE_BRANCH}" =~ ^master$|^[0-9]+\.[0-9]+\.X$ ]]; then bash build_tools/circle/push_doc.sh doc/_build/html/stable -<<<<<<< 16b3f13d9c9e44d9fb329800ddc37e42b000ffd6 -<<<<<<< 6c0faf614b525bad520269e28ec684b44c00c22a fi workflows: @@ -132,38 +130,3 @@ workflows: - master jobs: - pypy3 -======= - fi ->>>>>>> drop python 2 CI -======= - fi -<<<<<<< 5c998549df27cfb3dd21d7ce54b04229c24bd5cc ->>>>>>> fix merge conflict -======= - -workflows: - version: 2 - build-doc-and-deploy: - jobs: - - doc - - doc-min-dependencies - - lint - - pypy3: - filters: - branches: - only: - - 0.20.X - - deploy: - requires: - - python3 - pypy: - triggers: - - schedule: - cron: "0 0 * * *" - filters: - branches: - only: - - master - jobs: - - pypy3 ->>>>>>> fix merge conflicts From 4b76694b60678d9d254e0ef633c2e17a566062f3 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 21 Dec 2018 12:35:17 +0100 Subject: [PATCH 035/163] merge master CI --- .travis.yml | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/.travis.yml b/.travis.yml index bff6c5acb6672..2926f2df560ba 100644 --- a/.travis.yml +++ b/.travis.yml @@ -48,20 +48,7 @@ matrix: NUMPY_VERSION="1.11.0" SCIPY_VERSION="0.17.0" CYTHON_VERSION="*" PILLOW_VERSION="4.0.0" COVERAGE=true if: type != cron -<<<<<<< 8337ef1d2fffb8cdd6ab8a42e404ecf901c6f912 -<<<<<<< 91f39bd55a48e487720306b974f29c7de46d2209 # Linux environment to test the latest available dependencies and MKL. 
-======= - # Python 3.5 build tst - - env: DISTRIB="conda" PYTHON_VERSION="3.5" INSTALL_MKL="false" - NUMPY_VERSION="1.14" SCIPY_VERSION="1.0" CYTHON_VERSION="0.25.2" - PILLOW_VERSION="4.0.0" COVERAGE=true - SKLEARN_SITE_JOBLIB=1 JOBLIB_VERSION="0.11" - if: type != cron -======= ->>>>>>> fix merge conflicts - # This environment tests the latest available dependencies. ->>>>>>> tst build # It runs tests requiring pandas and PyAMG. # It also runs with the site joblib instead of the vendored copy of joblib. - env: DISTRIB="conda" PYTHON_VERSION="*" INSTALL_MKL="true" From e77ac243a3f99e363aae0c656819e706bec8cfde Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 21 Dec 2018 13:54:51 +0100 Subject: [PATCH 036/163] tst clang version --- build_tools/travis/install.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index d79f8845a3d89..5cab753976eed 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -27,6 +27,11 @@ then ccache --max-size 100M --show-stats fi +if [ $TRAVIS_OS_NAME = "osx" ] +then + which gcc + gcc --version + make_conda() { TO_INSTALL="$@" # Deactivate the travis-provided virtual environment and setup a From 9ccc7250cee63a562f33b63536cf8b3d2f18c2ed Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 21 Dec 2018 14:00:45 +0100 Subject: [PATCH 037/163] same --- build_tools/travis/install.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index 5cab753976eed..6a7093790e223 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -31,6 +31,7 @@ if [ $TRAVIS_OS_NAME = "osx" ] then which gcc gcc --version +fi make_conda() { TO_INSTALL="$@" From 9a03162ba93be9b630dc4aed9e1cf3c307212daa Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 21 Dec 2018 17:05:14 +0100 Subject: [PATCH 038/163] add llvm-openmp to travis --- 
build_tools/travis/install.sh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index 6a7093790e223..d0fb0409987d9 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -25,12 +25,13 @@ then # export CCACHE_LOGFILE=/tmp/ccache.log # ~60M is used by .ccache when compiling from scratch at the time of writing ccache --max-size 100M --show-stats -fi - -if [ $TRAVIS_OS_NAME = "osx" ] +elif [ $TRAVIS_OS_NAME = "osx" ] then - which gcc - gcc --version + # use clang installed by conda which supports OpenMP + export CC=clang + export CXX=clang + # avoid error due to multiple openmp libraries loaded simultaneously + export KMP_DUPLICATE_LIB_OK=TRUE fi make_conda() { @@ -44,6 +45,8 @@ make_conda() { if [ $TRAVIS_OS_NAME = "osx" ] then fname=Miniconda3-latest-MacOSX-x86_64.sh + # we need to install a version on clang which supports OpenMP + TO_INSTALL="$TO_INSTALL llvm-openmp clang" else fname=Miniconda3-latest-Linux-x86_64.sh fi From 215be9469f0fa8e1dc7e474c755688141755c0c8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 14 Jan 2019 13:21:29 +0100 Subject: [PATCH 039/163] appveyor codecov --- appveyor.yml | 4 +++- build_tools/appveyor/requirements.txt | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index e5c4362451e97..10d7ed5eb761d 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -78,7 +78,7 @@ test_script: } else { $env:PYTEST_ARGS = "" } - - "pytest --showlocals --durations=20 %PYTEST_ARGS% --pyargs sklearn" + - "pytest --showlocals --durations=20 %PYTEST_ARGS% --pyargs --cov=sklearn sklearn" # Move back to the project folder - cd "../scikit-learn" @@ -87,6 +87,8 @@ artifacts: - path: dist\* on_success: + - "cp ../empty_folder/.coverage ." 
+ - "codecov" # Upload the generated wheel package to Rackspace - "python -m wheelhouse_uploader upload --local-folder=dist sklearn-windows-wheels" diff --git a/build_tools/appveyor/requirements.txt b/build_tools/appveyor/requirements.txt index 1a2feca5c6b6b..40ddc39003e27 100644 --- a/build_tools/appveyor/requirements.txt +++ b/build_tools/appveyor/requirements.txt @@ -2,6 +2,9 @@ numpy scipy cython pytest +pytest-cov +coverage +codecov wheel wheelhouse_uploader pillow From bec907964f5bfb5a7a0c3f615f04c3f0247090f0 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 16 Jan 2019 10:35:32 +0100 Subject: [PATCH 040/163] openmp flags --- sklearn/cluster/setup.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py index fb83d38dd1e53..ca267de1dc713 100644 --- a/sklearn/cluster/setup.py +++ b/sklearn/cluster/setup.py @@ -1,10 +1,17 @@ # Author: Alexandre Gramfort # License: BSD 3 clause import os +import sys import numpy +def get_openmp_flag(): + if sys.platform == "win32": + return '/openmp' + return '-fopenmp' + + def configuration(parent_package='', top_path=None): from numpy.distutils.misc_util import Configuration @@ -33,15 +40,15 @@ def configuration(parent_package='', top_path=None): sources=['_k_means_lloyd.pyx'], include_dirs=[numpy.get_include()], libraries=libraries, - extra_link_args=['-fopenmp'], - extra_compile_args=['-fopenmp']) + extra_link_args=[get_openmp_flag()], + extra_compile_args=[get_openmp_flag()]) config.add_extension('_k_means_elkan', sources=['_k_means_elkan.pyx'], include_dirs=[numpy.get_include()], libraries=libraries, - extra_link_args=['-fopenmp'], - extra_compile_args=['-fopenmp']) + extra_link_args=[get_openmp_flag()], + extra_compile_args=[get_openmp_flag()]) config.add_subpackage('tests') From 43d3fba7c4d696793cc7c8ae5dbcf9b15dadd669 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 16 Jan 2019 14:57:33 +0100 Subject: 
[PATCH 041/163] openmp flags --- setup.py | 33 ++++++++++++++++++++++++++++++++- sklearn/cluster/setup.py | 15 ++------------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index cce21f5883c5a..984c8f90f4742 100755 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ import platform import shutil from distutils.command.clean import clean as Clean +from numpy.distutils.command.build_ext import build_ext from pkg_resources import parse_version import traceback import builtins @@ -102,7 +103,37 @@ def run(self): shutil.rmtree(os.path.join(dirpath, dirname)) -cmdclass = {'clean': CleanCommand} +def get_openmp_flag(compiler): + if sys.platform == "win32" and compiler.startswith('ic'): + return ['/Qopenmp'] + elif sys.platform == "win32": + return ['/openmp'] + elif sys.platform == "darwin" and compiler.startswith('ic'): + return ['-openmp'] + return ['-fopenmp'] + + +OPENMP_EXTENSIONS = ["sklearn.cluster._k_means_lloyd", + "sklearn.cluster._k_means_elkan"] + + +# custom build_ext command to set OpenMP compile flags depending on os and +# compiler +class build_ext_subclass(build_ext): + def build_extensions(self): + compiler = self.compiler.compiler[0] + openmp_flag = get_openmp_flag(compiler) + + for e in self.extensions: + if e.name in OPENMP_EXTENSIONS: + e.extra_compile_args += openmp_flag + e.extra_link_args += openmp_flag + + build_ext.build_extensions(self) + + +cmdclass = {'clean': CleanCommand, 'build_ext': build_ext_subclass} + # Optional wheelhouse-uploader features # To automate release of binary packages for scikit-learn we need a tool diff --git a/sklearn/cluster/setup.py b/sklearn/cluster/setup.py index ca267de1dc713..7a1e419a34883 100644 --- a/sklearn/cluster/setup.py +++ b/sklearn/cluster/setup.py @@ -1,17 +1,10 @@ # Author: Alexandre Gramfort # License: BSD 3 clause import os -import sys import numpy -def get_openmp_flag(): - if sys.platform == "win32": - return '/openmp' - return '-fopenmp' - - def 
configuration(parent_package='', top_path=None): from numpy.distutils.misc_util import Configuration @@ -39,16 +32,12 @@ def configuration(parent_package='', top_path=None): config.add_extension('_k_means_lloyd', sources=['_k_means_lloyd.pyx'], include_dirs=[numpy.get_include()], - libraries=libraries, - extra_link_args=[get_openmp_flag()], - extra_compile_args=[get_openmp_flag()]) + libraries=libraries) config.add_extension('_k_means_elkan', sources=['_k_means_elkan.pyx'], include_dirs=[numpy.get_include()], - libraries=libraries, - extra_link_args=[get_openmp_flag()], - extra_compile_args=[get_openmp_flag()]) + libraries=libraries) config.add_subpackage('tests') From 2c613785efad8465d03afbe4d420d798a0b867f8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 18 Jan 2019 11:21:59 +0100 Subject: [PATCH 042/163] openmp flags --- setup.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 984c8f90f4742..1e35c68579498 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,6 @@ import platform import shutil from distutils.command.clean import clean as Clean -from numpy.distutils.command.build_ext import build_ext from pkg_resources import parse_version import traceback import builtins @@ -54,7 +53,7 @@ 'develop', 'release', 'bdist_egg', 'bdist_rpm', 'bdist_wininst', 'install_egg_info', 'build_sphinx', 'egg_info', 'easy_install', 'upload', 'bdist_wheel', - '--single-version-externally-managed', + '--single-version-externally-managed', 'build_ext' ]) if SETUPTOOLS_COMMANDS.intersection(sys.argv): import setuptools @@ -104,11 +103,11 @@ def run(self): def get_openmp_flag(compiler): - if sys.platform == "win32" and compiler.startswith('ic'): + if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): return ['/Qopenmp'] elif sys.platform == "win32": return ['/openmp'] - elif sys.platform == "darwin" and compiler.startswith('ic'): + elif sys.platform == "darwin" and ('icc' in compiler or 
'icl' in compiler): return ['-openmp'] return ['-fopenmp'] @@ -119,9 +118,17 @@ def get_openmp_flag(compiler): # custom build_ext command to set OpenMP compile flags depending on os and # compiler +# build_ext has to be imported after setuptools +from numpy.distutils.command.build_ext import build_ext + + class build_ext_subclass(build_ext): def build_extensions(self): - compiler = self.compiler.compiler[0] + if hasattr(self.compiler, 'compiler'): + compiler = self.compiler.compiler[0] + else: + compiler = self.compiler.__class__.__name__ + openmp_flag = get_openmp_flag(compiler) for e in self.extensions: From 7679a9fdf5af5ddb67c80fa7ffc77617630c21d1 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 18 Jan 2019 11:29:21 +0100 Subject: [PATCH 043/163] fix conflicts --- sklearn/cluster/k_means_.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 75756b39fd699..b37ebb001c4f9 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -30,7 +30,6 @@ from ..utils.validation import FLOAT_DTYPES from ..utils._clibs import thread_limits_context from ..utils._joblib import effective_n_jobs -from ..externals.six import string_types from ..exceptions import ConvergenceWarning from ._k_means import (_inertia_dense, _inertia_sparse, From f9160913fda109cdb6fe6aec421978737dace33d Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 18 Jan 2019 11:32:08 +0100 Subject: [PATCH 044/163] ompenmp --- setup.py | 5 ++++- sklearn/cluster/k_means_.py | 10 +++++----- sklearn/utils/tests/test_clibs.py | 10 ++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 1e35c68579498..278f8b7fe1a82 100755 --- a/setup.py +++ b/setup.py @@ -109,6 +109,9 @@ def get_openmp_flag(compiler): return ['/openmp'] elif sys.platform == "darwin" and ('icc' in compiler or 'icl' in compiler): return ['-openmp'] + elif sys.platform == "darwin" and 'openmp' in 
os.getenv('CC', ''): + # -fopenmp can't be passed as compile arg when using apple clang + return [''] return ['-fopenmp'] @@ -119,7 +122,7 @@ def get_openmp_flag(compiler): # custom build_ext command to set OpenMP compile flags depending on os and # compiler # build_ext has to be imported after setuptools -from numpy.distutils.command.build_ext import build_ext +from numpy.distutils.command.build_ext import build_ext # noqa class build_ext_subclass(build_ext): diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index b37ebb001c4f9..a6093abac7040 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -826,15 +826,15 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): >>> from sklearn.cluster import KMeans >>> import numpy as np >>> X = np.array([[1, 2], [1, 4], [1, 0], - ... [4, 2], [4, 4], [4, 0]]) + ... [10, 2], [10, 4], [10, 0]]) >>> kmeans = KMeans(n_clusters=2, random_state=1234).fit(X) >>> kmeans.labels_ - array([1, 1, 1, 0, 0, 0], dtype=int32) + array([0, 0, 0, 1, 1, 1], dtype=int32) >>> kmeans.predict([[0, 0], [12, 3]]) - array([1, 0], dtype=int32) + array([0, 1], dtype=int32) >>> kmeans.cluster_centers_ - array([[10., 2.], - [ 1., 2.]]) + array([[ 1., 2.], + [10., 2.]]) See also -------- diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index c1e4376d00289..823905ce0d2c3 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -11,6 +11,16 @@ SKIP_OPENBLAS = get_openblas_version() is None +def test_openmp_enabled(): + # Check that an OpenMP library is loaded + limits = get_thread_limits() + + assert not all([lib is None for lib in [limits['openmp_llvm'], + limits['openmp_gnu'], + limits['openmp_win32'], + limits['openmp_intel']]]) + + @pytest.mark.parametrize("clib", _CLibsWrapper.SUPPORTED_CLIBS) def test_set_thread_limits_dict(clib): # Check that the number of threads used by the multithreaded C-libs can be From 
4c2da0c2c86ea36da1c04fa47644bbc71548654c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 25 Jan 2019 18:41:47 +0100 Subject: [PATCH 045/163] no need --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 278f8b7fe1a82..be0efa00412a9 100755 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ 'develop', 'release', 'bdist_egg', 'bdist_rpm', 'bdist_wininst', 'install_egg_info', 'build_sphinx', 'egg_info', 'easy_install', 'upload', 'bdist_wheel', - '--single-version-externally-managed', 'build_ext' + '--single-version-externally-managed', ]) if SETUPTOOLS_COMMANDS.intersection(sys.argv): import setuptools From a4383fb240a75a4d582c21969d5a95904e2d370c Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 31 Jan 2019 14:10:38 +0100 Subject: [PATCH 046/163] flake8 --- sklearn/cluster/k_means_.py | 5 ++--- sklearn/utils/tests/test_clibs.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index a6093abac7040..a65e0b9c35203 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -302,9 +302,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', """ if precompute_distances != 'not-used': - warnings.warn("'precompute_distances' was deprecated in version" - "0.21 and will be removed in 0.23.", - DeprecationWarning) + warnings.warn("'precompute_distances' was deprecated in version" + "0.21 and will be removed in 0.23.", DeprecationWarning) if n_init <= 0: raise ValueError("Invalid number of initializations." diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py index 823905ce0d2c3..43aad0d8666a8 100644 --- a/sklearn/utils/tests/test_clibs.py +++ b/sklearn/utils/tests/test_clibs.py @@ -26,7 +26,7 @@ def test_set_thread_limits_dict(clib): # Check that the number of threads used by the multithreaded C-libs can be # modified dynamically. 
- if clib is "openblas" and SKIP_OPENBLAS: + if clib == "openblas" and SKIP_OPENBLAS: raise SkipTest("Possible bug in getting maximum number of threads with" " OpenBLAS < 0.2.16 and OpenBLAS does not expose it's " "version before 0.3.4.") From cf253832ac655e5c105e192ee1821d11092f1567 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 31 Jan 2019 15:06:33 +0100 Subject: [PATCH 047/163] force init order --- sklearn/cluster/k_means_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index a65e0b9c35203..79de64485b18b 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -327,7 +327,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', # Validate init array if hasattr(init, '__array__'): - init = check_array(init, dtype=X.dtype.type, copy=True) + init = check_array(init, dtype=X.dtype.type, copy=True, order='C') _validate_center_shape(X, n_clusters, init) if n_init != 1: From 986863260fe4398e1516e547e4e164c510f6872b Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 31 Jan 2019 15:32:50 +0100 Subject: [PATCH 048/163] remove forced X order --- sklearn/cluster/_k_means_elkan.pyx | 4 ++-- sklearn/cluster/_k_means_lloyd.pyx | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index efdb104ade822..2acaec00aeb90 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -43,7 +43,7 @@ cdef floating euclidean_dist(floating* a, floating* b, int n_features) nogil: return sqrt(result) -cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, +cpdef _init_bounds(np.ndarray[floating, ndim=2] X, floating[:, ::1] centers, floating[:, ::1] center_half_distances, int[::1] labels, @@ -109,7 +109,7 @@ cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, upper_bounds[i] = min_dist -cpdef void 
_elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, +cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2] X, floating[::1] sample_weight, floating[:, ::1] centers_old, floating[:, ::1] centers_new, diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index f75490756c61b..1eff39a0d4c50 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -31,7 +31,7 @@ cdef void xgemm(char *ta, char *tb, int *m, int *n, int *k, floating *alpha, dgemm(ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) -cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, +cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2] X, floating[::1] sample_weight, floating[::1] x_squared_norms, floating[:, ::1] centers_old, From 7326362ec6ab8af3144f7f8d6c73e7e226f51c98 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 31 Jan 2019 15:39:07 +0100 Subject: [PATCH 049/163] same --- sklearn/cluster/_k_means.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 04502bf36c88c..32baf85910476 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -22,7 +22,7 @@ ctypedef np.float64_t DOUBLE ctypedef np.int32_t INT -cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, +cpdef floating _inertia_dense(np.ndarray[floating, ndim=2] X, floating[::1] sample_weight, floating[:, ::1] centers, int[::1] labels): @@ -88,7 +88,7 @@ cpdef floating _inertia_sparse(X, return inertia -cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, +cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2] X, floating[::1] sample_weight, floating[:, ::1] centers, floating[::1] weight_in_clusters, From 212ae77b72f973d8cee757f67681d40fc39ddfaf Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 31 Jan 2019 
15:44:10 +0100 Subject: [PATCH 050/163] same --- sklearn/cluster/_k_means.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index 13b65491b8bae..0a3593754b4ec 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -6,7 +6,7 @@ cimport numpy as np cdef void _relocate_empty_clusters_dense( - np.ndarray[floating, ndim=2, mode='c'], + np.ndarray[floating, ndim=2], floating[::1], floating[:, ::1], floating[::1], From e9a4ceebf5493032ff0744d41b9fab628669bdec Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Sun, 3 Feb 2019 19:45:44 +0100 Subject: [PATCH 051/163] directly use _cython_blas --- sklearn/cluster/_k_means_lloyd.pyx | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 1eff39a0d4c50..939ad65d92e66 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -8,12 +8,13 @@ cimport numpy as np cimport openmp from cython cimport floating from cython.parallel import prange, parallel -from scipy.linalg.cython_blas cimport sgemm, dgemm from libc.math cimport sqrt from libc.stdlib cimport calloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX +from ..utils._cython_blas cimport _gemm +from ..utils._cython_blas cimport RowMajor, Trans, NoTrans from ._k_means cimport (_relocate_empty_clusters_dense, _relocate_empty_clusters_sparse, _mean_and_center_shift) @@ -22,15 +23,6 @@ from ._k_means cimport (_relocate_empty_clusters_dense, np.import_array() -cdef void xgemm(char *ta, char *tb, int *m, int *n, int *k, floating *alpha, - floating *A, int *lda, floating *B, int *ldb, floating *beta, - floating *C, int *ldc) nogil: - if floating is float: - sgemm(ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) - else: - dgemm(ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) - - 
cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2] X, floating[::1] sample_weight, floating[::1] x_squared_norms, @@ -194,12 +186,6 @@ cdef void _update_chunk_dense(floating *X, cdef: floating sq_dist, min_sq_dist int i, j, k, best_cluster - - # parameters for the BLAS gemm - floating alpha = -2.0 - floating beta = 1.0 - char *trans_data = 'n' - char *trans_centers = 't' # Instead of computing the full pairwise squared distances matrix, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store @@ -209,9 +195,9 @@ cdef void _update_chunk_dense(floating *X, for j in range(n_clusters): pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] - xgemm(trans_centers, trans_data, &n_clusters, &n_samples, &n_features, - &alpha, centers_old, &n_features, X, &n_features, - &beta, pairwise_distances, &n_clusters) + _gemm(RowMajor, NoTrans, Trans, n_samples, n_clusters, n_features, + -2.0, X, n_features, centers_old, n_features, + 1.0, pairwise_distances, n_clusters) for i in range(n_samples): min_sq_dist = pairwise_distances[i * n_clusters] From 41ea6dffb18b54e45ebf3422f4cf62b3bc6f4a59 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Sun, 3 Feb 2019 19:58:41 +0100 Subject: [PATCH 052/163] ensure order='C' even if copy_x = false --- sklearn/cluster/_k_means.pxd | 2 +- sklearn/cluster/_k_means.pyx | 4 ++-- sklearn/cluster/_k_means_elkan.pyx | 4 ++-- sklearn/cluster/_k_means_lloyd.pyx | 2 +- sklearn/cluster/k_means_.py | 14 ++++++-------- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index 0a3593754b4ec..13b65491b8bae 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -6,7 +6,7 @@ cimport numpy as np cdef void _relocate_empty_clusters_dense( - np.ndarray[floating, ndim=2], + np.ndarray[floating, ndim=2, mode='c'], floating[::1], floating[:, ::1], floating[::1], diff --git a/sklearn/cluster/_k_means.pyx 
b/sklearn/cluster/_k_means.pyx index 32baf85910476..04502bf36c88c 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -22,7 +22,7 @@ ctypedef np.float64_t DOUBLE ctypedef np.int32_t INT -cpdef floating _inertia_dense(np.ndarray[floating, ndim=2] X, +cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers, int[::1] labels): @@ -88,7 +88,7 @@ cpdef floating _inertia_sparse(X, return inertia -cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2] X, +cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers, floating[::1] weight_in_clusters, diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 2acaec00aeb90..efdb104ade822 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -43,7 +43,7 @@ cdef floating euclidean_dist(floating* a, floating* b, int n_features) nogil: return sqrt(result) -cpdef _init_bounds(np.ndarray[floating, ndim=2] X, +cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] centers, floating[:, ::1] center_half_distances, int[::1] labels, @@ -109,7 +109,7 @@ cpdef _init_bounds(np.ndarray[floating, ndim=2] X, upper_bounds[i] = min_dist -cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2] X, +cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers_old, floating[:, ::1] centers_new, diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 939ad65d92e66..661f6771e9a5e 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -23,7 +23,7 @@ from ._k_means cimport (_relocate_empty_clusters_dense, np.import_array() -cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2] X, +cpdef void 
_lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[::1] x_squared_norms, floating[:, ::1] centers_old, diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 161e75c0902ed..36112cfbe53d4 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -259,11 +259,11 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', copy_x : boolean, optional When pre-computing distances it is more numerically accurate to center the data first. If copy_x is True (default), then the original data is - not modified, ensuring X is C-contiguous. If False, the original data - is modified, and put back before the function returns, but small - numerical differences may be introduced by subtracting and then adding - the data mean, in this case it will also not ensure that data is - C-contiguous which may cause a significant slowdown. + not modified. If False, the original data is modified, and put back + before the function returns, but small numerical differences may be + introduced by subtracting and then adding the data mean. Note that if + the original data is not C-contiguous, a copy will be made even if + copy_x is False. n_jobs : int or None, optional (default=None) The number of jobs to use for the computation. 
This works by computing @@ -313,10 +313,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', raise ValueError('Number of iterations should be a positive number,' ' got %d instead' % max_iter) - # avoid forcing order when copy_x=False - order = "C" if copy_x else None X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32], - order=order, copy=copy_x) + order='C', copy=copy_x) # verify that the number of samples given is larger than k if _num_samples(X) < n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( From 745c75639d314ffba96001651c3af2e1e41aa8b5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Sun, 3 Feb 2019 20:01:10 +0100 Subject: [PATCH 053/163] remove unnecessary condition --- sklearn/cluster/_k_means.pyx | 44 +++++++++++++++++------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 04502bf36c88c..4fb4f54a5e82b 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -111,22 +111,21 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] int new_cluster_id, old_cluster_id, far_idx, idx, k floating weight - if n_empty > 0: - for idx in range(n_empty): + for idx in range(n_empty): - new_cluster_id = empty_clusters[idx] + new_cluster_id = empty_clusters[idx] - far_idx = far_from_centers[idx] - weight = sample_weight[far_idx] + far_idx = far_from_centers[idx] + weight = sample_weight[far_idx] - old_cluster_id = labels[far_idx] + old_cluster_id = labels[far_idx] - for k in range(n_features): - centers[new_cluster_id, k] = X[far_idx, k] * weight - centers[old_cluster_id, k] -= X[far_idx, k] * weight + for k in range(n_features): + centers[new_cluster_id, k] = X[far_idx, k] * weight + centers[old_cluster_id, k] -= X[far_idx, k] * weight - weight_in_clusters[new_cluster_id] = weight - weight_in_clusters[old_cluster_id] -= weight + 
weight_in_clusters[new_cluster_id] = weight + weight_in_clusters[old_cluster_id] -= weight cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, @@ -163,22 +162,21 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int new_cluster_id, old_cluster_id, far_idx, idx floating weight - if n_empty > 0: - for idx in range(n_empty): + for idx in range(n_empty): - new_cluster_id = empty_clusters[idx] + new_cluster_id = empty_clusters[idx] - far_idx = far_from_centers[idx] - weight = sample_weight[far_idx] + far_idx = far_from_centers[idx] + weight = sample_weight[far_idx] - old_cluster_id = labels[far_idx] - - for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): - centers[new_cluster_id, X_indices[k]] += X_data[k] * weight - centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight + old_cluster_id = labels[far_idx] + + for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): + centers[new_cluster_id, X_indices[k]] += X_data[k] * weight + centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight - weight_in_clusters[new_cluster_id] = weight - weight_in_clusters[old_cluster_id] -= weight + weight_in_clusters[new_cluster_id] = weight + weight_in_clusters[old_cluster_id] -= weight cdef void _mean_and_center_shift(floating[:, ::1] centers_old, From 930be8229b745d86b718088057c472b6d8b18a14 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 4 Feb 2019 11:53:34 +0100 Subject: [PATCH 054/163] merge master --- build_tools/travis/install.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index a4f1734b3f90b..110a8661ed7c0 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -54,8 +54,6 @@ make_conda() { if [ $TRAVIS_OS_NAME = "osx" ] then fname=Miniconda3-latest-MacOSX-x86_64.sh - # we need to install a version on clang which supports OpenMP - TO_INSTALL="$TO_INSTALL llvm-openmp clang" else fname=Miniconda3-latest-Linux-x86_64.sh fi From 
084db447bb960a5ccbfee6ea3c22814af518806a Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 6 Feb 2019 09:40:01 +0100 Subject: [PATCH 055/163] copy_x docstring --- sklearn/cluster/k_means_.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 36112cfbe53d4..00e42e1adf27d 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -780,11 +780,11 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): copy_x : boolean, optional When pre-computing distances it is more numerically accurate to center the data first. If copy_x is True (default), then the original data is - not modified, ensuring X is C-contiguous. If False, the original data - is modified, and put back before the function returns, but small - numerical differences may be introduced by subtracting and then adding - the data mean, in this case it will also not ensure that data is - C-contiguous which may cause a significant slowdown. + not modified. If False, the original data is modified, and put back + before the function returns, but small numerical differences may be + introduced by subtracting and then adding the data mean. Note that if + the original data is not C-contiguous, a copy will be made even if + copy_x is False. n_jobs : int or None, optional (default=None) The number of jobs to use for the computation. 
This works by computing From 55a656340472c971213d26380fafc970ea7bb9b2 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 12 Feb 2019 18:21:07 +0100 Subject: [PATCH 056/163] refactor, use memviews more, add sparse elkan --- sklearn/cluster/_k_means.pxd | 41 +-- sklearn/cluster/_k_means.pyx | 109 ++++-- sklearn/cluster/_k_means_elkan.pyx | 468 ++++++++++++++++++++------ sklearn/cluster/_k_means_lloyd.pyx | 314 ++++++++--------- sklearn/cluster/k_means_.py | 70 ++-- sklearn/cluster/tests/test_k_means.py | 53 +-- 6 files changed, 667 insertions(+), 388 deletions(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index 13b65491b8bae..a005250ad37e2 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -5,29 +5,18 @@ from cython cimport floating cimport numpy as np -cdef void _relocate_empty_clusters_dense( - np.ndarray[floating, ndim=2, mode='c'], - floating[::1], - floating[:, ::1], - floating[::1], - int[::1] -) - - -cdef void _relocate_empty_clusters_sparse( - floating[::1], - int[::1], - int[::1], - floating[::1], - floating[:, ::1], - floating[::1], - int[::1] -) - - -cdef void _mean_and_center_shift( - floating[:, ::1], - floating[:, ::1], - floating[::1], - floating[::1] -) +cdef floating _euclidean_dense_dense(floating*, floating*, int, bint) nogil + +cdef floating _euclidean_sparse_dense(floating[::1], int[::1], floating[::1], + floating, bint) nogil + +cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'], + floating[::1], floating[:, ::1], + floating[::1], int[::1]) + +cdef void _relocate_empty_clusters_sparse(floating[::1], int[::1], int[::1], + floating[::1], floating[:, ::1], + floating[::1], int[::1]) + +cdef void _mean_and_center_shift(floating[:, ::1], floating[:, ::1], + floating[::1], floating[::1]) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 4fb4f54a5e82b..600bda9256780 100644 --- a/sklearn/cluster/_k_means.pyx +++ 
b/sklearn/cluster/_k_means.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # Profiling is enabled by default as the overhead does not seem to be # measurable on this specific use case. @@ -14,6 +14,8 @@ cimport cython from cython cimport floating from libc.math cimport sqrt +from ..utils.extmath import row_norms + np.import_array() @@ -22,29 +24,76 @@ ctypedef np.float64_t DOUBLE ctypedef np.int32_t INT +cdef floating _euclidean_dense_dense(floating* a, + floating* b, + int n_features, + bint squared) nogil: + """Euclidean distance between a dense and b dense""" + cdef: + int i + int n = n_features // 4 + int rem = n_features % 4 + floating result = 0 + + for i in range(n): + result += ((a[0] - b[0]) * (a[0] - b[0]) + +(a[1] - b[1]) * (a[1] - b[1]) + +(a[2] - b[2]) * (a[2] - b[2]) + +(a[3] - b[3]) * (a[3] - b[3])) + a += 4; b += 4 + + for i in range(rem): + result += (a[i] - b[i]) * (a[i] - b[i]) + + if not squared: result = sqrt(result) + + return result + + +cdef floating _euclidean_sparse_dense(floating[::1] a_data, + int[::1] a_indices, + floating[::1] b, + floating b_squared_norm, + bint squared) nogil: + """Euclidean distance between a sparse and b dense""" + cdef: + int nnz = len(a_indices) + int i + floating tmp = 0.0 + floating result = 0.0 + + for i in range(nnz): + tmp = a_data[i] - b[a_indices[i]] + result += tmp * tmp - b[a_indices[i]] * b[a_indices[i]] + + result += b_squared_norm + + if not squared: result = sqrt(result) + + return result + + cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers, int[::1] labels): """Compute inertia for dense input data - Sum of squared distance between each sample and it's assigned center. + Sum of squared distance between each sample and its assigned center.
""" cdef: int n_samples = X.shape[0] int n_features = X.shape[1] - int i, j, k - floating tmp, sample_inertia + int i, j + floating sq_dist = 0.0 floating inertia = 0.0 for i in range(n_samples): j = labels[i] - sample_inertia = 0.0 - for k in range(n_features): - tmp = X[i, k] - centers[j, k] - sample_inertia += tmp * tmp - inertia += sample_inertia * sample_weight[i] + sq_dist = _euclidean_dense_dense(&X[i, 0], &centers[j, 0], + n_features, True) + inertia += sq_dist * sample_weight[i] return inertia @@ -55,35 +104,29 @@ cpdef floating _inertia_sparse(X, int[::1] labels): """Compute inertia for sparse input data - Sum of squared distance between each sample and it's assigned center. + Sum of squared distance between each sample and its assigned center. """ cdef: floating[::1] X_data = X.data int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - int n_samples = X_indptr.shape[0] - 1 - int n_features = centers.shape[1] - int i, j, k - int row_ptr, nz_len, nz_ptr - floating tmp, sample_inertia + int n_samples = X.shape[0] + int n_features = X.shape[1] + int i, j + floating sq_dist = 0.0 floating inertia = 0.0 + + floating[::1] center_squared_norms = row_norms(centers, squared=True) for i in range(n_samples): j = labels[i] - sample_inertia = 0.0 - row_ptr = X_indptr[i] - nz_len = X_indptr[i + 1] - X_indptr[i] - nz_ptr = 0 - for k in range(n_features): - if nz_ptr < nz_len and k == X_indices[row_ptr + nz_ptr]: - tmp = X_data[row_ptr + nz_ptr] - centers[j, k] - nz_ptr += 1 - else: - tmp = - centers[j, k] - sample_inertia += tmp * tmp - inertia += sample_inertia * sample_weight[i] + sq_dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[j], center_squared_norms[j], True) + inertia += sq_dist * sample_weight[i] return inertia @@ -93,9 +136,9 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] floating[:, ::1] centers, floating[::1] weight_in_clusters, int[::1]
labels): - """Relocate centers which have no sample assigned to them""" + """Relocate centers which have no sample assigned to them.""" cdef: - int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters, 0))[0].astype(np.int32) int n_empty = empty_clusters.shape[0] if n_empty == 0: @@ -135,14 +178,14 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, floating[:, ::1] centers, floating[::1] weight_in_clusters, int[::1] labels): - """Relocate centers which have no sample assigned to them""" + """Relocate centers which have no sample assigned to them.""" cdef: - int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters, 0))[0].astype(np.int32) int n_empty = empty_clusters.shape[0] if n_empty == 0: return - + cdef: int n_samples = X_indptr.shape[0] - 1 floating x @@ -183,7 +226,7 @@ cdef void _mean_and_center_shift(floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] weight_in_clusters, floating[::1] center_shift): - """Average new centers wrt weights and compute center shift""" + """Average new centers wrt weights and compute center shift.""" cdef: int n_clusters = centers_old.shape[0] int n_features = centers_old.shape[1] diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index efdb104ade822..4318a82842c88 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # cython: language_level=3 # # Author: Andreas Mueller @@ -15,40 +15,23 @@ from libc.math cimport sqrt from libc.stdlib cimport calloc, free from libc.string cimport memset, memcpy -from ..metrics import euclidean_distances -from 
._k_means cimport _relocate_empty_clusters_dense, _mean_and_center_shift +from ..utils.extmath import row_norms +from ._k_means cimport _relocate_empty_clusters_dense +from ._k_means cimport _relocate_empty_clusters_sparse +from ._k_means cimport _mean_and_center_shift +from ._k_means cimport _euclidean_dense_dense +from ._k_means cimport _euclidean_sparse_dense np.import_array() -cdef floating euclidean_dist(floating* a, floating* b, int n_features) nogil: - """Euclidean distance between a and b, optimized for vectorization""" - cdef: - int i - int n = n_features // 4 - int rem = n_features % 4 - floating result = 0 - - for i in range(n): - result += ((a[0] - b[0]) * (a[0] - b[0]) - +(a[1] - b[1]) * (a[1] - b[1]) - +(a[2] - b[2]) * (a[2] - b[2]) - +(a[3] - b[3]) * (a[3] - b[3])) - a += 4; b += 4 - - for i in range(rem): - result += (a[i] - b[i]) * (a[i] - b[i]) - - return sqrt(result) - - -cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, - floating[:, ::1] centers, - floating[:, ::1] center_half_distances, - int[::1] labels, - floating[::1] upper_bounds, - floating[:, ::1] lower_bounds): +cpdef _init_bounds_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[:, ::1] centers, + floating[:, ::1] center_half_distances, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds): """Initialize upper and lower bounds for each sample. 
Given X, centers and the pairwise distances divided by 2.0 between the @@ -96,11 +79,55 @@ cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, for i in range(n_samples): best_cluster = 0 - min_dist = euclidean_dist(&X[i, 0], &centers[0, 0], n_features) + min_dist = _euclidean_dense_dense(&X[i, 0], &centers[0, 0], + n_features, False) lower_bounds[i, 0] = min_dist for j in range(1, n_clusters): if min_dist > center_half_distances[best_cluster, j]: - dist = euclidean_dist(&X[i, 0], &centers[j, 0], n_features) + dist = _euclidean_dense_dense(&X[i, 0], &centers[j, 0], + n_features, False) + lower_bounds[i, j] = dist + if dist < min_dist: + min_dist = dist + best_cluster = j + labels[i] = best_cluster + upper_bounds[i] = min_dist + + +cpdef _init_bounds_sparse(X, + floating[:, ::1] centers, + floating[:, ::1] center_half_distances, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds): + cdef: + int n_samples = X.shape[0] + int n_clusters = centers.shape[0] + int n_features = X.shape[1] + + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + floating min_dist, dist + int best_cluster, i, j + + floating[::1] centers_squared_norms = row_norms(centers, squared=True) + + for i in range(n_samples): + best_cluster = 0 + min_dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[0], centers_squared_norms[0], False) + + lower_bounds[i, 0] = min_dist + for j in range(1, n_clusters): + if min_dist > center_half_distances[best_cluster, j]: + dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[j], centers_squared_norms[j], False) lower_bounds[i, j] = dist if dist < min_dist: min_dist = dist @@ -120,9 +147,9 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] lower_bounds, int[::1] labels, floating[::1] center_shift, - int n_jobs
= -1, - bint update_centers = True): - """Single interation of K-means elkan algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means elkan algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -153,7 +180,7 @@ shape (n_clusters, n_clusters) distance_next_center : {float32, float64} array-like, shape (n_clusters,) Distance between each center it's closest center. - + upper_bounds : {float32, float64} array-like, shape (n_samples,) Upper bound for the distance between each sample and it's center, updated inplace. @@ -189,17 +216,15 @@ shape (n_clusters, n_clusters) int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff + int start, end int num_threads int i, j, k - int label - floating alpha, tmp, x - - floating *centers_new_chunk - floating *weight_in_clusters_chunk - # count remainder chunk in total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 # re-initialize all arrays at each iteration if update_centers: @@ -211,50 +236,262 @@ shape (n_clusters, n_clusters) num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) - + for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk - _update_chunk( - &X[chunk_idx * n_samples_chunk, 0], - &sample_weight[chunk_idx * n_samples_chunk], - &centers_old[0, 0], - centers_new_chunk, -
&center_half_distances[0, 0], - &distance_next_center[0], - weight_in_clusters_chunk, - &labels[chunk_idx * n_samples_chunk], - &upper_bounds[chunk_idx * n_samples_chunk], - &lower_bounds[chunk_idx * n_samples_chunk, 0], - n_samples_chunk_eff, - n_clusters, - n_features, + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end], + centers_old, + centers_new, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] + if update_centers: + _relocate_empty_clusters_dense( + X, sample_weight, centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) + + # update lower and upper bounds + for i in range(n_samples): + upper_bounds[i] += center_shift[labels[i]] + + for j in range(n_clusters): + lower_bounds[i, j] -= center_shift[j] + if lower_bounds[i, j] < 0: + lower_bounds[i, j] = 0 + + +cdef void _update_chunk_dense(floating *X, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] weight_in_clusters, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + + Compute the partial contribution of a single data chunk to the labels and + centers.
+ """ + cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + + floating upper_bound, distance + int i, j, k, label + + for i in range(n_samples): + upper_bound = upper_bounds[i] + bounds_tight = 0 + label = labels[i] + + # Next center is not far away from the currently assigned center. + # Sample might need to be assigned to another center. + if not distance_next_center[label] >= upper_bound: + + for j in range(n_clusters): + + # If this holds, then center_index is a good candidate for the + # sample to be relabelled, and we need to confirm this by + # recomputing the upper and lower bounds. + if (j != label + and (upper_bound > lower_bounds[i, j]) + and (upper_bound > center_half_distances[label, j])): + + # Recompute upper bound by calculating the actual distance + # between the sample and it's current assigned center. + if not bounds_tight: + upper_bound = _euclidean_dense_dense( + X + i * n_features, &centers_old[label, 0], n_features, False) + lower_bounds[i, label] = upper_bound + bounds_tight = 1 + + # If the condition still holds, then compute the actual + # distance between the sample and center. If this is less + # than the previous distance, reassign label. + if (upper_bound > lower_bounds[i, j] + or (upper_bound > center_half_distances[label, j])): + + distance = _euclidean_dense_dense( + X + i * n_features, &centers_old[j, 0], n_features, False) + lower_bounds[i, j] = distance + if distance < upper_bound: + label = j + upper_bound = distance + + labels[i] = label + upper_bounds[i] = upper_bound + + if update_centers: + # The gil is necessary for that to avoid race conditions.
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(n_features): + centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] + + +cpdef void _elkan_iter_chunked_sparse(X, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] weight_in_clusters, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + int[::1] labels, + floating[::1] center_shift, + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means elkan algorithm with sparse input + + Update labels and centers (inplace), for one iteration, distributed + over data chunks. + + Parameters + ---------- + X : {float32, float64} CSR matrix, shape (n_samples, n_features) + The observations to cluster. + + sample_weight : {float32, float64} array-like, shape (n_samples,) + The weights for each observation in X. + + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) + Centers before previous iteration, placeholder for the centers after + previous iteration. + + centers_new : {float32, float64} array-like, shape (n_clusters, n_features) + Centers after previous iteration, placeholder for the new centers + computed during this iteration. + + weight_in_clusters : {float32, float64} array-like, shape (n_clusters,) + Placeholder for the sums of the weights of every observation assigned + to each center. + + center_half_distances : {float32, float64} array-like, \ +shape (n_clusters, n_clusters) + Half pairwise distances between centers. + + distance_next_center : {float32, float64} array-like, shape (n_clusters,) + Distance between each center it's closest center. + + upper_bounds : {float32, float64} array-like, shape (n_samples,) + Upper bound for the distance between each sample and it's center, + updated inplace. 
+ + lower_bounds : {float32, float64} array-like, shape (n_samples, n_clusters) + Lower bound for the distance between each sample and each center, + updated inplace. + + labels : int array-like, shape (n_samples,) + labels assignment. + + center_shift : {float32, float64} array-like, shape (n_clusters,) + Distance between old and new centers. + + n_jobs : int + The number of threads to be used by openmp. If -1, openmp will use as + many as possible. - free(weight_in_clusters_chunk) - free(centers_new_chunk) + update_centers : bool + - If True, the labels and the new centers will be computed, i.e. runs + the E-step and the M-step of the algorithm. + - If False, only the labels will be computed, i.e runs the E-step of + the algorithm. + """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int n_clusters = centers_new.shape[0] + + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + # hard-coded number of samples per chunk. Splitting in chunks is + # necessary to get parallelism. 
Chunk size chosed to be same as lloyd's + int n_samples_chunk = 256 if n_samples > 256 else n_samples + int n_chunks = n_samples // n_samples_chunk + int n_samples_r = n_samples % n_samples_chunk + int chunk_idx, n_samples_chunk_eff + int start, end + int num_threads + + int i, j, k + + floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 + + # re-initialize all arrays at each iteration if update_centers: - _relocate_empty_clusters_dense(X, sample_weight, centers_new, - weight_in_clusters, labels) + memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) + memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) + memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) + # set number of threads to be used by openmp + num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + with nogil, parallel(num_threads=num_threads): + + for chunk_idx in prange(n_chunks): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + centers_old, + centers_new, + centers_squared_norms, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + update_centers) + + if update_centers: + _relocate_empty_clusters_sparse( + X_data, X_indices, X_indptr, sample_weight, + centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old,
centers_new, weight_in_clusters, center_shift) # update lower and upper bounds for i in range(n_samples): @@ -266,28 +503,33 @@ shape (n_clusters, n_clusters) lower_bounds[i, j] = 0 -cdef void _update_chunk(floating *X, - floating *sample_weight, - floating *centers_old, - floating *centers_new, - floating *center_half_distances, - floating *distance_next_center, - floating *weight_in_clusters, - int *labels, - floating *upper_bounds, - floating *lower_bounds, - int n_samples, - int n_clusters, - int n_features, - bint update_centers) nogil: - """K-means step for one data chunk using elkan algorithm - +cdef void _update_chunk_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] weight_in_clusters, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + Compute the partial contribution of a single data chunk to the labels and centers. """ cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating upper_bound, distance int i, j, k, label + int s = X_indptr[0] for i in range(n_samples): upper_bound = upper_bounds[i] @@ -304,28 +546,29 @@ cdef void _update_chunk(floating *X, # sample to be relabelled, and we need to confirm this by # recomputing the upper and lower bounds. if (j != label - and (upper_bound > lower_bounds[i * n_clusters + j]) - and (upper_bound > center_half_distances[label * n_clusters + j])): + and (upper_bound > lower_bounds[i, j]) + and (upper_bound > center_half_distances[label, j])): # Recompute upper bound by calculating the actual distance # between the sample and it's current assigned center. 
if not bounds_tight: - upper_bound = euclidean_dist(X + i * n_features, - centers_old + label * n_features, - n_features) - lower_bounds[i * n_clusters + label] = upper_bound + upper_bound = _euclidean_sparse_dense( + X_data[X_indptr[i] - s: X_indptr[i + 1] -s], + X_indices[X_indptr[i] -s: X_indptr[i + 1] -s], + centers_old[label], centers_squared_norms[label], False) + lower_bounds[i, label] = upper_bound bounds_tight = 1 # If the condition still holds, then compute the actual # distance between the sample and center. If this is less - #than the previous distance, reassign label. - if (upper_bound > lower_bounds[i * n_clusters + j] - or (upper_bound > center_half_distances[label * n_clusters + j])): - - distance = euclidean_dist(X + i * n_features, - centers_old + j * n_features, - n_features) - lower_bounds[i * n_clusters + j] = distance + # than the previous distance, reassign label. + if (upper_bound > lower_bounds[i, j] + or (upper_bound > center_half_distances[label, j])): + distance = _euclidean_sparse_dense( + X_data[X_indptr[i] - s: X_indptr[i + 1] -s], + X_indices[X_indptr[i] -s: X_indptr[i + 1] -s], + centers_old[j], centers_squared_norms[j], False) + lower_bounds[i, j] = distance if distance < upper_bound: label = j upper_bound = distance @@ -333,7 +576,10 @@ cdef void _update_chunk(floating *X, labels[i] = label upper_bounds[i] = upper_bound - if update_centers: - weight_in_clusters[label] += sample_weight[i] - for k in range(n_features): - centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i] \ No newline at end of file + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 661f6771e9a5e..e9e44ded79a5c 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # cython: language_level=3 # # Licence: BSD 3 clause @@ -9,15 +9,16 @@ cimport openmp from cython cimport floating from cython.parallel import prange, parallel from libc.math cimport sqrt -from libc.stdlib cimport calloc, free +from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX +from ..utils.extmath import row_norms from ..utils._cython_blas cimport _gemm from ..utils._cython_blas cimport RowMajor, Trans, NoTrans -from ._k_means cimport (_relocate_empty_clusters_dense, - _relocate_empty_clusters_sparse, - _mean_and_center_shift) +from ._k_means cimport _relocate_empty_clusters_dense +from ._k_means cimport _relocate_empty_clusters_sparse +from ._k_means cimport _mean_and_center_shift np.import_array() @@ -29,12 +30,12 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, + floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs = -1, - bint update_centers = True): - """Single interation of K-means lloyd algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means lloyd algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. 
@@ -49,7 +50,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, x_squared_norms : {float32, float64} array-like, shape (n_samples,) Squared L2 norm of X. - + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) Centers before previous iteration, placeholder for the centers after previous iteration. @@ -57,7 +58,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, centers_new : {float32, float64} array-like, shape (n_clusters, n_features) Centers after previous iteration, placeholder for the new centers computed during this iteration. - + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) Squared L2 norm of the centers. @@ -67,7 +68,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, labels : int array-like, shape (n_samples,) labels assignment. - + center_shift : {float32, float64} array-like, shape (n_clusters,) Distance between old and new centers. 
@@ -92,23 +93,18 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff + int start, end int num_threads int j, k - floating alpha - floating *centers_new_chunk - floating *weight_in_clusters_chunk - floating *pairwise_distances_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 - # count remainder chunk in total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - memset(&centers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in range(n_clusters): - for k in range(n_features): - centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) @@ -117,75 +113,65 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) - pairwise_distances_chunk = calloc(n_samples_chunk * n_clusters, sizeof(floating)) - + for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + _update_chunk_dense( - &X[chunk_idx * n_samples_chunk, 0], - &sample_weight[chunk_idx * 
n_samples_chunk], - &x_squared_norms[chunk_idx * n_samples_chunk], - &centers_old[0, 0], - centers_new_chunk, - &centers_squared_norms[0], - weight_in_clusters_chunk, - pairwise_distances_chunk, - &labels[chunk_idx * n_samples_chunk], - n_samples_chunk_eff, - n_clusters, - n_features, + &X[start, 0], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] - - free(weight_in_clusters_chunk) - free(centers_new_chunk) - free(pairwise_distances_chunk) - if update_centers: - _relocate_empty_clusters_dense(X, sample_weight, centers_new, - weight_in_clusters, labels) + _relocate_empty_clusters_dense( + X, sample_weight, centers_new, weight_in_clusters, labels) - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) cdef void _update_chunk_dense(floating *X, - floating *sample_weight, - floating *x_squared_norms, - floating *centers_old, - floating *centers_new, - floating *centers_squared_norms, - floating *weight_in_clusters, - floating *pairwise_distances, - int *labels, - int n_samples, - int n_clusters, - int n_features, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, bint update_centers) nogil: """K-means combined EM step for one data chunk - + Compute the partial contribution of a single data chunk to the labels and centers. 
""" cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating sq_dist, min_sq_dist - int i, j, k, best_cluster + int i, j, k, label + + floating *pairwise_distances_ptr = malloc(n_samples * n_clusters * sizeof(floating)) + floating[:, ::1] pairwise_distances + + with gil: + pairwise_distances = pairwise_distances_ptr # Instead of computing the full pairwise squared distances matrix, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store @@ -193,27 +179,31 @@ cdef void _update_chunk_dense(floating *X, # depends on the centers. for i in range(n_samples): for j in range(n_clusters): - pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] - + pairwise_distances[i, j] = centers_squared_norms[j] + _gemm(RowMajor, NoTrans, Trans, n_samples, n_clusters, n_features, - -2.0, X, n_features, centers_old, n_features, - 1.0, pairwise_distances, n_clusters) + -2.0, X, n_features, &centers_old[0, 0], n_features, + 1.0, pairwise_distances_ptr, n_clusters) for i in range(n_samples): - min_sq_dist = pairwise_distances[i * n_clusters] - best_cluster = 0 - for j in range(n_clusters): - sq_dist = pairwise_distances[i * n_clusters + j] + min_sq_dist = pairwise_distances[i, 0] + label = 0 + for j in range(1, n_clusters): + sq_dist = pairwise_distances[i, j] if sq_dist < min_sq_dist: min_sq_dist = sq_dist - best_cluster = j + label = j + labels[i] = label - labels[i] = best_cluster + free(pairwise_distances_ptr) - if update_centers: - weight_in_clusters[best_cluster] += sample_weight[i] - for k in range(n_features): - centers_new[best_cluster * n_features + k] += X[i * n_features + k] * sample_weight[i] + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(n_features): + centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] cpdef void _lloyd_iter_chunked_sparse(X, @@ -222,12 +212,12 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, + floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs = -1, - bint update_centers = True): - """Single interation of K-means lloyd algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means lloyd algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -242,7 +232,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, x_squared_norms : {float32, float64} array-like, shape (n_samples,) Squared L2 norm of X. - + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) Centers before previous iteration, placeholder for the centers after previous iteration. @@ -250,7 +240,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, centers_new : {float32, float64} array-like, shape (n_clusters, n_features) Centers after previous iteration, placeholder for the new centers computed during this iteration. - + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) Squared L2 norm of the centers. @@ -260,7 +250,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, labels : int array-like, shape (n_samples,) labels assignment. - + center_shift : {float32, float64} array-like, shape (n_clusters,) Distance between old and new centers. 
@@ -283,7 +273,8 @@ cpdef void _lloyd_iter_chunked_sparse(X, int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk - int chunk_idx, n_samples_chunk_eff + int chunk_idx, n_samples_chunk_eff = 0 + int start = 0, end = 0 int num_threads int j, k @@ -293,17 +284,13 @@ cpdef void _lloyd_iter_chunked_sparse(X, int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - floating *centers_new_chunk - floating *weight_in_clusters_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 - # count remainder for total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - memset(&centers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in range(n_clusters): - for k in range(n_features): - centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) @@ -312,76 +299,64 @@ cpdef void _lloyd_iter_chunked_sparse(X, # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + _update_chunk_sparse( - &X_data[X_indptr[chunk_idx * n_samples_chunk]], - 
&X_indices[X_indptr[chunk_idx * n_samples_chunk]], - &X_indptr[chunk_idx * n_samples_chunk], - &sample_weight[chunk_idx * n_samples_chunk], - &x_squared_norms[chunk_idx * n_samples_chunk], - &centers_old[0, 0], - centers_new_chunk, - &centers_squared_norms[0], - weight_in_clusters_chunk, - &labels[chunk_idx * n_samples_chunk], - n_samples_chunk_eff, - n_clusters, - n_features, + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] - - free(weight_in_clusters_chunk) - free(centers_new_chunk) - if update_centers: - _relocate_empty_clusters_sparse(X_data, X_indices, X_indptr, - sample_weight, centers_new, - weight_in_clusters, labels) - - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) - - -cdef void _update_chunk_sparse(floating *X_data, - int *X_indices, - int *X_indptr, - floating *sample_weight, - floating *x_squared_norms, - floating *centers_old, - floating *centers_new, - floating *centers_squared_norms, - floating *weight_in_cluster, - int *labels, - int n_samples, - int n_clusters, - int n_features, + _relocate_empty_clusters_sparse( + X_data, X_indices, X_indptr, sample_weight, + centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) + + +cdef void _update_chunk_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, + 
floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, bint update_centers) nogil: """K-means combined EM step for one data chunk - + Compute the partial contribution of a single data chunk to the labels and centers. """ - cdef: + cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating sq_dist, min_sq_dist - int i, j, k, best_cluster + int i, j, k, label floating max_floating = FLT_MAX if floating is float else DBL_MAX int s = X_indptr[0] @@ -390,13 +365,13 @@ cdef void _update_chunk_sparse(floating *X_data, # multiplication is available. for i in range(n_samples): min_sq_dist = max_floating - best_cluster = 0 + label = 0 for j in range(n_clusters): sq_dist = 0.0 for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - sq_dist += centers_old[j * n_features + X_indices[k]] * X_data[k] - + sq_dist += centers_old[j, X_indices[k]] * X_data[k] + # Instead of computing the full squared distance with each cluster, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to compute # the - 2 X.C^T + ||C||² term since the argmin for a given sample @@ -404,11 +379,14 @@ cdef void _update_chunk_sparse(floating *X_data, sq_dist = centers_squared_norms[j] -2 * sq_dist if sq_dist < min_sq_dist: min_sq_dist = sq_dist - best_cluster = j - - labels[i] = best_cluster - - if update_centers: - weight_in_cluster[best_cluster] += sample_weight[i] - for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[best_cluster * n_features + X_indices[k]] += X_data[k] * sample_weight[i] + label = j + + labels[i] = label + + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 00e42e1adf27d..657a444fd268a 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -30,13 +30,15 @@ from ..utils._clibs import thread_limits_context from ..utils._joblib import effective_n_jobs from ..exceptions import ConvergenceWarning -from ._k_means import (_inertia_dense, - _inertia_sparse, - _mini_batch_update_csr) -from ._k_means_lloyd import (_lloyd_iter_chunked_dense, - _lloyd_iter_chunked_sparse) -from ._k_means_elkan import (_init_bounds, - _elkan_iter_chunked_dense) +from ._k_means import _inertia_dense +from ._k_means import _inertia_sparse +from ._k_means import _mini_batch_update_csr +from ._k_means_lloyd import _lloyd_iter_chunked_dense +from ._k_means_lloyd import _lloyd_iter_chunked_sparse +from ._k_means_elkan import _init_bounds_dense +from ._k_means_elkan import _init_bounds_sparse +from ._k_means_elkan import _elkan_iter_chunked_dense +from ._k_means_elkan import _elkan_iter_chunked_sparse ############################################################################### @@ -348,13 +350,14 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', best_labels, best_inertia, best_centers = None, None, None - if algorithm == "auto": - algorithm = "full" if sp.issparse(X) else "elkan" if algorithm == "elkan" and n_clusters == 1: warnings.warn("algorithm='elkan' doesn't make sense for a single " "cluster. 
Using 'full' instead.", RuntimeWarning) algorithm = "full" + if algorithm == "auto": + algorithm = "elkan" + if algorithm == "full": kmeans_single = _kmeans_single_lloyd elif algorithm == "elkan": @@ -403,8 +406,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, n_jobs=None): - if sp.issparse(X): - raise TypeError("algorithm='elkan' not supported for sparse input X") + # if sp.issparse(X): + # raise TypeError("algorithm='elkan' not supported for sparse input X") random_state = check_random_state(random_state) sample_weight = _check_sample_weight(X, sample_weight) @@ -422,29 +425,37 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) labels = np.full(n_samples, -1, dtype=np.int32) center_half_distances = euclidean_distances(centers) / 2 - distance_next_center = np.zeros(n_clusters, dtype=X.dtype) + distance_next_center = np.partition(np.asarray(center_half_distances), + kth=1, axis=0)[1] upper_bounds = np.zeros(n_samples, dtype=X.dtype) lower_bounds = np.zeros((n_samples, n_clusters), dtype=X.dtype) center_shift = np.zeros(n_clusters, dtype=X.dtype) - _init_bounds(X, centers, center_half_distances, - labels, upper_bounds, lower_bounds) + if sp.issparse(X): + init_bounds = _init_bounds_sparse + elkan_iter = _elkan_iter_chunked_sparse + _inertia = _inertia_sparse + else: + init_bounds = _init_bounds_dense + elkan_iter = _elkan_iter_chunked_dense + _inertia = _inertia_dense + + init_bounds(X, centers, center_half_distances, + labels, upper_bounds, lower_bounds) for i in range(max_iter): - # compute the closest other center of each center - distance_next_center = np.partition(np.asarray(center_half_distances), - kth=1, axis=0)[1] - - _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, - weight_in_clusters, 
center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs) + elkan_iter(X, sample_weight, centers_old, centers, weight_in_clusters, + center_half_distances, distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs) - # compute new pairwise distances between centers for next iterations + # compute new pairwise distances between centers and closest other + # center of each center for next iterations center_half_distances = euclidean_distances(centers) / 2 + distance_next_center = np.partition(np.asarray(center_half_distances), + kth=1, axis=0)[1] if verbose: - inertia = _inertia_dense(X, sample_weight, centers_old, labels) + inertia = _inertia(X, sample_weight, centers_old, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) center_shift_tot = (center_shift**2).sum() @@ -456,13 +467,12 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, break # rerun E-step so that predicted labels match cluster centers - _elkan_iter_chunked_dense(X, sample_weight, centers, centers, - weight_in_clusters, center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs, - update_centers=False) + elkan_iter(X, sample_weight, centers, centers, weight_in_clusters, + center_half_distances, distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs, + update_centers=False) - inertia = _inertia_dense(X, sample_weight, centers, labels) + inertia = _inertia(X, sample_weight, centers, labels) return labels, inertia, centers, i + 1 diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 9e3d1271d3c70..289540f8ca93d 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -46,10 +46,8 @@ X_csr = sp.csr_matrix(X) -@pytest.mark.parametrize("representation, algo", - [('dense', 'full'), - ('dense', 'elkan'), - ('sparse', 'full')]) 
+@pytest.mark.parametrize("representation", ['dense', 'sparse']) +@pytest.mark.parametrize("algo", ['full', 'elkan']) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_kmeans_results(representation, algo, dtype): # cheks that kmeans works as intended @@ -92,6 +90,29 @@ def test_elkan_results(distribution): assert_array_equal(km_elkan.labels_, km_full.labels_) +@pytest.mark.parametrize('distribution', ['normal', 'blobs']) +def test_elkan_results_sparse(distribution): + # check that results are identical between lloyd and elkan algorithms + # with sparse input + rnd = np.random.RandomState(0) + if distribution is 'normal': + X = sp.random(100, 100, density=0.1, format='csr', random_state=rnd) + X.data = rnd.randn(len(X.data)) + else: + X, _ = make_blobs(n_samples=100, n_features=100, random_state=rnd) + X = sp.csr_matrix(X) + + km_full = KMeans(algorithm='full', n_clusters=5, random_state=0, n_init=1) + km_elkan = KMeans(algorithm='elkan', n_clusters=5, + random_state=0, n_init=1) + + km_full.fit(X) + km_elkan.fit(X) + assert_array_almost_equal(km_elkan.cluster_centers_, + km_full.cluster_centers_) + assert_array_equal(km_elkan.labels_, km_full.labels_) + + def test_labels_assignment_and_inertia(): # pure numpy implementation as easily auditable reference gold # implementation @@ -311,20 +332,17 @@ def test_k_means_fit_predict(algo, dtype, constructor, seed, max_iter, tol): # There's a very small chance of failure with elkan on unstructured dataset # because predict method uses fast euclidean distances computation which # may cause small numerical instabilities. 
- if not (algo == 'elkan' and constructor is sp.csr_matrix): - rng = np.random.RandomState(seed) + X = make_blobs(n_samples=1000, n_features=10, centers=10, + random_state=seed)[0].astype(dtype, copy=False) + X = constructor(X) - X = make_blobs(n_samples=1000, n_features=10, centers=10, - random_state=rng)[0].astype(dtype, copy=False) - X = constructor(X) + kmeans = KMeans(algorithm=algo, n_clusters=10, random_state=seed, + tol=tol, max_iter=max_iter, n_jobs=1) - kmeans = KMeans(algorithm=algo, n_clusters=10, random_state=seed, - tol=tol, max_iter=max_iter, n_jobs=1) + labels_1 = kmeans.fit(X).predict(X) + labels_2 = kmeans.fit_predict(X) - labels_1 = kmeans.fit(X).predict(X) - labels_2 = kmeans.fit_predict(X) - - assert_array_equal(labels_1, labels_2) + assert_array_equal(labels_1, labels_2) def test_mb_kmeans_verbose(): @@ -695,11 +713,6 @@ def test_k_means_function(): assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, sample_weight=None) - # kmeans for algorithm='elkan' raises TypeError on sparse matrix - assert_raise_message(TypeError, "algorithm='elkan' not supported for " - "sparse input X", k_means, X=X_csr, n_clusters=2, - sample_weight=None, algorithm="elkan") - def test_x_squared_norms_init_centroids(): # Test that x_squared_norms can be None in _init_centroids From 84617128d9dd760cdf63fb2ae96c82d8e3e28cd7 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 12 Feb 2019 18:21:07 +0100 Subject: [PATCH 057/163] refactor, use memviews more, add sparse elkan --- sklearn/cluster/_k_means.pxd | 41 +-- sklearn/cluster/_k_means.pyx | 109 ++++-- sklearn/cluster/_k_means_elkan.pyx | 468 ++++++++++++++++++++------ sklearn/cluster/_k_means_lloyd.pyx | 314 ++++++++--------- sklearn/cluster/k_means_.py | 70 ++-- sklearn/cluster/tests/test_k_means.py | 53 +-- 6 files changed, 667 insertions(+), 388 deletions(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index 13b65491b8bae..a005250ad37e2 100644 --- 
a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -5,29 +5,18 @@ from cython cimport floating cimport numpy as np -cdef void _relocate_empty_clusters_dense( - np.ndarray[floating, ndim=2, mode='c'], - floating[::1], - floating[:, ::1], - floating[::1], - int[::1] -) - - -cdef void _relocate_empty_clusters_sparse( - floating[::1], - int[::1], - int[::1], - floating[::1], - floating[:, ::1], - floating[::1], - int[::1] -) - - -cdef void _mean_and_center_shift( - floating[:, ::1], - floating[:, ::1], - floating[::1], - floating[::1] -) +cdef floating _euclidean_dense_dense(floating*, floating*, int, bint) nogil + +cdef floating _euclidean_sparse_dense(floating[::1], int[::1], floating[::1], + floating, bint) nogil + +cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'], + floating[::1], floating[:, ::1], + floating[::1], int[::1]) + +cdef void _relocate_empty_clusters_sparse(floating[::1], int[::1], int[::1], + floating[::1], floating[:, ::1], + floating[::1], int[::1]) + +cdef void _mean_and_center_shift(floating[:, ::1], floating[:, ::1], + floating[::1], floating[::1]) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 4fb4f54a5e82b..600bda9256780 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # Profiling is enabled by default as the overhead does not seem to be # measurable on this specific use case. 
@@ -14,6 +14,8 @@ cimport cython from cython cimport floating from libc.math cimport sqrt +from ..utils.extmath import row_norms + np.import_array() @@ -22,29 +24,76 @@ ctypedef np.float64_t DOUBLE ctypedef np.int32_t INT +cdef floating _euclidean_dense_dense(floating* a, + floating* b, + int n_features, + bint squared) nogil: + """Euclidean distance between a dense and b dense""" + cdef: + int i + int n = n_features // 4 + int rem = n_features % 4 + floating result = 0 + + for i in range(n): + result += ((a[0] - b[0]) * (a[0] - b[0]) + +(a[1] - b[1]) * (a[1] - b[1]) + +(a[2] - b[2]) * (a[2] - b[2]) + +(a[3] - b[3]) * (a[3] - b[3])) + a += 4; b += 4 + + for i in range(rem): + result += (a[i] - b[i]) * (a[i] - b[i]) + + if not squared: result = sqrt(result) + + return result + + +cdef floating _euclidean_sparse_dense(floating[::1] a_data, + int[::1] a_indices, + floating[::1] b, + floating b_squared_norm, + bint squared) nogil: + """Euclidean distance between a sparse and b dense""" + cdef: + int nnz = len(a_indices) + int i + floating tmp = 0.0 + floating result = 0.0 + + for i in range(nnz): + tmp = a_data[i] - b[a_indices[i]] + result += tmp * tmp - b[a_indices[i]] * b[a_indices[i]] + + result += b_squared_norm + + if not squared: result = sqrt(result) + + return result + + cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, floating[:, ::1] centers, int[::1] labels): """Compute inertia for dense input data - Sum of squared distance between each sample and it's assigned center. + Sum of squared distance between each sample and its assigned center. 
""" cdef: int n_samples = X.shape[0] int n_features = X.shape[1] - int i, j, k - floating tmp, sample_inertia + int i, j + floating sq_dist = 0.0 floating inertia = 0.0 for i in range(n_samples): j = labels[i] - sample_inertia = 0.0 - for k in range(n_features): - tmp = X[i, k] - centers[j, k] - sample_inertia += tmp * tmp - inertia += sample_inertia * sample_weight[i] + sq_dist = _euclidean_dense_dense(&X[i, 0], ¢ers[j, 0], + n_features, True) + inertia += sq_dist * sample_weight[i] return inertia @@ -55,35 +104,29 @@ cpdef floating _inertia_sparse(X, int[::1] labels): """Compute inertia for sparse input data - Sum of squared distance between each sample and it's assigned center. + Sum of squared distance between each sample and its assigned center. """ cdef: floating[::1] X_data = X.data int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - int n_samples = X_indptr.shape[0] - 1 - int n_features = centers.shape[1] - int i, j, k - int row_ptr, nz_len, nz_ptr - floating tmp, sample_inertia + int n_samples = X.shape[0] + int n_features = X.shape[1] + int i, j + floating sq_dist = 0.0 floating inertia = 0.0 + + floating[::1] center_squared_norms = row_norms(centers, squared=True) for i in range(n_samples): j = labels[i] - sample_inertia = 0.0 - row_ptr = X_indptr[i] - nz_len = X_indptr[i + 1] - X_indptr[i] - nz_ptr = 0 - for k in range(n_features): - if nz_ptr < nz_len and k == X_indices[row_ptr + nz_ptr]: - tmp = X_data[row_ptr + nz_ptr] - centers[j, k] - nz_ptr += 1 - else: - tmp = - centers[j, k] - sample_inertia += tmp * tmp - inertia += sample_inertia * sample_weight[i] + sq_dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[j], center_squared_norms[j], True) + inertia += sq_dist * sample_weight[i] return inertia @@ -93,9 +136,9 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] floating[:, ::1] centers, floating[::1] weight_in_clusters, int[::1] 
labels): - """Relocate centers which have no sample assigned to them""" + """Relocate centers which have no sample assigned to them.""" cdef: - int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters, 0))[0].astype(np.int32) int n_empty = empty_clusters.shape[0] if n_empty == 0: @@ -135,14 +178,14 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, floating[:, ::1] centers, floating[::1] weight_in_clusters, int[::1] labels): - """Relocate centers which have no sample assigned to them""" + """Relocate centers which have no sample assigned to them.""" cdef: - int[::1] empty_clusters = np.where(np.equal(weight_in_clusters,0))[0].astype(np.int32) + int[::1] empty_clusters = np.where(np.equal(weight_in_clusters, 0))[0].astype(np.int32) int n_empty = empty_clusters.shape[0] if n_empty == 0: return - + cdef: int n_samples = X_indptr.shape[0] - 1 floating x @@ -183,7 +226,7 @@ cdef void _mean_and_center_shift(floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] weight_in_clusters, floating[::1] center_shift): - """Average new centers wrt weights and compute center shift""" + """Average new centers wrt weights and compute center shift.""" cdef: int n_clusters = centers_old.shape[0] int n_features = centers_old.shape[1] diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index efdb104ade822..4318a82842c88 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # cython: language_level=3 # # Author: Andreas Mueller @@ -15,40 +15,23 @@ from libc.math cimport sqrt from libc.stdlib cimport calloc, free from libc.string cimport memset, memcpy -from ..metrics import euclidean_distances -from 
._k_means cimport _relocate_empty_clusters_dense, _mean_and_center_shift +from ..utils.extmath import row_norms +from ._k_means cimport _relocate_empty_clusters_dense +from ._k_means cimport _relocate_empty_clusters_sparse +from ._k_means cimport _mean_and_center_shift +from ._k_means cimport _euclidean_dense_dense +from ._k_means cimport _euclidean_sparse_dense np.import_array() -cdef floating euclidean_dist(floating* a, floating* b, int n_features) nogil: - """Euclidean distance between a and b, optimized for vectorization""" - cdef: - int i - int n = n_features // 4 - int rem = n_features % 4 - floating result = 0 - - for i in range(n): - result += ((a[0] - b[0]) * (a[0] - b[0]) - +(a[1] - b[1]) * (a[1] - b[1]) - +(a[2] - b[2]) * (a[2] - b[2]) - +(a[3] - b[3]) * (a[3] - b[3])) - a += 4; b += 4 - - for i in range(rem): - result += (a[i] - b[i]) * (a[i] - b[i]) - - return sqrt(result) - - -cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, - floating[:, ::1] centers, - floating[:, ::1] center_half_distances, - int[::1] labels, - floating[::1] upper_bounds, - floating[:, ::1] lower_bounds): +cpdef _init_bounds_dense(np.ndarray[floating, ndim=2, mode='c'] X, + floating[:, ::1] centers, + floating[:, ::1] center_half_distances, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds): """Initialize upper and lower bounds for each sample. 
Given X, centers and the pairwise distances divided by 2.0 between the @@ -96,11 +79,55 @@ cpdef _init_bounds(np.ndarray[floating, ndim=2, mode='c'] X, for i in range(n_samples): best_cluster = 0 - min_dist = euclidean_dist(&X[i, 0], ¢ers[0, 0], n_features) + min_dist = _euclidean_dense_dense(&X[i, 0], ¢ers[0, 0], + n_features, False) lower_bounds[i, 0] = min_dist for j in range(1, n_clusters): if min_dist > center_half_distances[best_cluster, j]: - dist = euclidean_dist(&X[i, 0], ¢ers[j, 0], n_features) + dist = _euclidean_dense_dense(&X[i, 0], ¢ers[j, 0], + n_features, False) + lower_bounds[i, j] = dist + if dist < min_dist: + min_dist = dist + best_cluster = j + labels[i] = best_cluster + upper_bounds[i] = min_dist + + +cpdef _init_bounds_sparse(X, + floating[:, ::1] centers, + floating[:, ::1] center_half_distances, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds): + cdef: + int n_samples = X.shape[0] + int n_clusters = centers.shape[0] + int n_features = X.shape[1] + + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + floating min_dist, dist + int best_cluster, i, j + + floating[::1] centers_squared_norms = row_norms(centers, squared=True) + + for i in range(n_samples): + best_cluster = 0 + min_dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[0], centers_squared_norms[0], False) + + lower_bounds[i, 0] = min_dist + for j in range(1, n_clusters): + if min_dist > center_half_distances[best_cluster, j]: + dist = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers[j], centers_squared_norms[j], False) lower_bounds[i, j] = dist if dist < min_dist: min_dist = dist @@ -120,9 +147,9 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] lower_bounds, int[::1] labels, floating[::1] center_shift, - int n_jobs 
= -1, - bint update_centers = True): - """Single interation of K-means elkan algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means elkan algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -153,7 +180,7 @@ shape (n_clusters, n_clusters) distance_next_center : {float32, float64} array-like, shape (n_clusters,) Distance between each center it's closest center. - + upper_bounds : {float32, float64} array-like, shape (n_samples,) Upper bound for the distance between each sample and it's center, updated inplace. @@ -189,17 +216,15 @@ shape (n_clusters, n_clusters) int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff + int start, end int num_threads int i, j, k - int label - floating alpha, tmp, x - - floating *centers_new_chunk - floating *weight_in_clusters_chunk - # count remainder chunk in total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 # re-initialize all arrays at each iteration if update_centers: @@ -211,50 +236,262 @@ shape (n_clusters, n_clusters) num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) - + for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk - _update_chunk( - &X[chunk_idx * n_samples_chunk, 0], - &sample_weight[chunk_idx * n_samples_chunk], - ¢ers_old[0, 0], - centers_new_chunk, - 
¢er_half_distances[0, 0], - &distance_next_center[0], - weight_in_clusters_chunk, - &labels[chunk_idx * n_samples_chunk], - &upper_bounds[chunk_idx * n_samples_chunk], - &lower_bounds[chunk_idx * n_samples_chunk, 0], - n_samples_chunk_eff, - n_clusters, - n_features, + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end], + centers_old, + centers_new, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] + if update_centers: + _relocate_empty_clusters_dense( + X, sample_weight, centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) + + # update lower and upper bounds + for i in range(n_samples): + upper_bounds[i] += center_shift[labels[i]] + + for j in range(n_clusters): + lower_bounds[i, j] -= center_shift[j] + if lower_bounds[i, j] < 0: + lower_bounds[i, j] = 0 + + +cdef void _update_chunk_dense(floating *X, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] weight_in_clusters, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + + Compute the partial contribution of a single data chunk to the labels and + centers. 
+ """ + cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + + floating upper_bound, distance + int i, j, k, label + + for i in range(n_samples): + upper_bound = upper_bounds[i] + bounds_tight = 0 + label = labels[i] + + # Next center is not far away from the currently assigned center. + # Sample might need to be assigned to another center. + if not distance_next_center[label] >= upper_bound: + + for j in range(n_clusters): + + # If this holds, then center_index is a good candidate for the + # sample to be relabelled, and we need to confirm this by + # recomputing the upper and lower bounds. + if (j != label + and (upper_bound > lower_bounds[i, j]) + and (upper_bound > center_half_distances[label, j])): + + # Recompute upper bound by calculating the actual distance + # between the sample and it's current assigned center. + if not bounds_tight: + upper_bound = _euclidean_dense_dense( + X + i * n_features, ¢ers_old[label, 0], n_features, False) + lower_bounds[i, label] = upper_bound + bounds_tight = 1 + + # If the condition still holds, then compute the actual + # distance between the sample and center. If this is less + # than the previous distance, reassign label. + if (upper_bound > lower_bounds[i, j] + or (upper_bound > center_half_distances[label, j])): + + distance = _euclidean_dense_dense( + X + i * n_features, ¢ers_old[j, 0], n_features, False) + lower_bounds[i, j] = distance + if distance < upper_bound: + label = j + upper_bound = distance + + labels[i] = label + upper_bounds[i] = upper_bound + + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(n_features): + centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] + + +cpdef void _elkan_iter_chunked_sparse(X, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] weight_in_clusters, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + int[::1] labels, + floating[::1] center_shift, + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means elkan algorithm with sparse input + + Update labels and centers (inplace), for one iteration, distributed + over data chunks. + + Parameters + ---------- + X : {float32, float64} CSR matrix, shape (n_samples, n_features) + The observations to cluster. + + sample_weight : {float32, float64} array-like, shape (n_samples,) + The weights for each observation in X. + + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) + Centers before previous iteration, placeholder for the centers after + previous iteration. + + centers_new : {float32, float64} array-like, shape (n_clusters, n_features) + Centers after previous iteration, placeholder for the new centers + computed during this iteration. + + weight_in_clusters : {float32, float64} array-like, shape (n_clusters,) + Placeholder for the sums of the weights of every observation assigned + to each center. + + center_half_distances : {float32, float64} array-like, \ +shape (n_clusters, n_clusters) + Half pairwise distances between centers. + + distance_next_center : {float32, float64} array-like, shape (n_clusters,) + Distance between each center it's closest center. + + upper_bounds : {float32, float64} array-like, shape (n_samples,) + Upper bound for the distance between each sample and it's center, + updated inplace. 
+ + lower_bounds : {float32, float64} array-like, shape (n_samples, n_clusters) + Lower bound for the distance between each sample and each center, + updated inplace. + + labels : int array-like, shape (n_samples,) + labels assignment. + + center_shift : {float32, float64} array-like, shape (n_clusters,) + Distance between old and new centers. + + n_jobs : int + The number of threads to be used by openmp. If -1, openmp will use as + many as possible. - free(weight_in_clusters_chunk) - free(centers_new_chunk) + update_centers : bool + - If True, the labels and the new centers will be computed, i.e. runs + the E-step and the M-step of the algorithm. + - If False, only the labels will be computed, i.e runs the E-step of + the algorithm. + """ + cdef: + int n_samples = X.shape[0] + int n_features = X.shape[1] + int n_clusters = centers_new.shape[0] + + floating[::1] X_data = X.data + int[::1] X_indices = X.indices + int[::1] X_indptr = X.indptr + + # hard-coded number of samples per chunk. Splitting in chunks is + # necessary to get parallelism. 
Chunk size chosed to be same as lloyd's + int n_samples_chunk = 256 if n_samples > 256 else n_samples + int n_chunks = n_samples // n_samples_chunk + int n_samples_r = n_samples % n_samples_chunk + int chunk_idx, n_samples_chunk_eff + int start, end + int num_threads + + int i, j, k + + floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 + + # re-initialize all arrays at each iteration if update_centers: - _relocate_empty_clusters_dense(X, sample_weight, centers_new, - weight_in_clusters, labels) + memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) + memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) + memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) + # set number of threads to be used by openmp + num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + with nogil, parallel(num_threads=num_threads): + + for chunk_idx in prange(n_chunks): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + centers_old, + centers_new, + centers_squared_norms, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + update_centers) + + if update_centers: + _relocate_empty_clusters_sparse( + X_data, X_indices, X_indptr, sample_weight, + centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old, 
centers_new, weight_in_clusters, center_shift) # update lower and upper bounds for i in range(n_samples): @@ -266,28 +503,33 @@ shape (n_clusters, n_clusters) lower_bounds[i, j] = 0 -cdef void _update_chunk(floating *X, - floating *sample_weight, - floating *centers_old, - floating *centers_new, - floating *center_half_distances, - floating *distance_next_center, - floating *weight_in_clusters, - int *labels, - floating *upper_bounds, - floating *lower_bounds, - int n_samples, - int n_clusters, - int n_features, - bint update_centers) nogil: - """K-means step for one data chunk using elkan algorithm - +cdef void _update_chunk_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[:, ::1] center_half_distances, + floating[::1] distance_next_center, + floating[::1] weight_in_clusters, + int[::1] labels, + floating[::1] upper_bounds, + floating[:, ::1] lower_bounds, + bint update_centers) nogil: + """K-means combined EM step for one data chunk + Compute the partial contribution of a single data chunk to the labels and centers. """ cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating upper_bound, distance int i, j, k, label + int s = X_indptr[0] for i in range(n_samples): upper_bound = upper_bounds[i] @@ -304,28 +546,29 @@ cdef void _update_chunk(floating *X, # sample to be relabelled, and we need to confirm this by # recomputing the upper and lower bounds. if (j != label - and (upper_bound > lower_bounds[i * n_clusters + j]) - and (upper_bound > center_half_distances[label * n_clusters + j])): + and (upper_bound > lower_bounds[i, j]) + and (upper_bound > center_half_distances[label, j])): # Recompute upper bound by calculating the actual distance # between the sample and it's current assigned center. 
if not bounds_tight: - upper_bound = euclidean_dist(X + i * n_features, - centers_old + label * n_features, - n_features) - lower_bounds[i * n_clusters + label] = upper_bound + upper_bound = _euclidean_sparse_dense( + X_data[X_indptr[i] - s: X_indptr[i + 1] -s], + X_indices[X_indptr[i] -s: X_indptr[i + 1] -s], + centers_old[label], centers_squared_norms[label], False) + lower_bounds[i, label] = upper_bound bounds_tight = 1 # If the condition still holds, then compute the actual # distance between the sample and center. If this is less - #than the previous distance, reassign label. - if (upper_bound > lower_bounds[i * n_clusters + j] - or (upper_bound > center_half_distances[label * n_clusters + j])): - - distance = euclidean_dist(X + i * n_features, - centers_old + j * n_features, - n_features) - lower_bounds[i * n_clusters + j] = distance + # than the previous distance, reassign label. + if (upper_bound > lower_bounds[i, j] + or (upper_bound > center_half_distances[label, j])): + distance = _euclidean_sparse_dense( + X_data[X_indptr[i] - s: X_indptr[i + 1] -s], + X_indices[X_indptr[i] -s: X_indptr[i + 1] -s], + centers_old[j], centers_squared_norms[j], False) + lower_bounds[i, j] = distance if distance < upper_bound: label = j upper_bound = distance @@ -333,7 +576,10 @@ cdef void _update_chunk(floating *X, labels[i] = label upper_bounds[i] = upper_bound - if update_centers: - weight_in_clusters[label] += sample_weight[i] - for k in range(n_features): - centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i] \ No newline at end of file + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 661f6771e9a5e..e9e44ded79a5c 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False # cython: language_level=3 # # Licence: BSD 3 clause @@ -9,15 +9,16 @@ cimport openmp from cython cimport floating from cython.parallel import prange, parallel from libc.math cimport sqrt -from libc.stdlib cimport calloc, free +from libc.stdlib cimport malloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX +from ..utils.extmath import row_norms from ..utils._cython_blas cimport _gemm from ..utils._cython_blas cimport RowMajor, Trans, NoTrans -from ._k_means cimport (_relocate_empty_clusters_dense, - _relocate_empty_clusters_sparse, - _mean_and_center_shift) +from ._k_means cimport _relocate_empty_clusters_dense +from ._k_means cimport _relocate_empty_clusters_sparse +from ._k_means cimport _mean_and_center_shift np.import_array() @@ -29,12 +30,12 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, + floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs = -1, - bint update_centers = True): - """Single interation of K-means lloyd algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means lloyd algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. 
@@ -49,7 +50,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, x_squared_norms : {float32, float64} array-like, shape (n_samples,) Squared L2 norm of X. - + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) Centers before previous iteration, placeholder for the centers after previous iteration. @@ -57,7 +58,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, centers_new : {float32, float64} array-like, shape (n_clusters, n_features) Centers after previous iteration, placeholder for the new centers computed during this iteration. - + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) Squared L2 norm of the centers. @@ -67,7 +68,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, labels : int array-like, shape (n_samples,) labels assignment. - + center_shift : {float32, float64} array-like, shape (n_clusters,) Distance between old and new centers. 
@@ -92,23 +93,18 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff + int start, end int num_threads int j, k - floating alpha - floating *centers_new_chunk - floating *weight_in_clusters_chunk - floating *pairwise_distances_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 - # count remainder chunk in total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - memset(¢ers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in range(n_clusters): - for k in range(n_features): - centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) @@ -117,75 +113,65 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) - pairwise_distances_chunk = calloc(n_samples_chunk * n_clusters, sizeof(floating)) - + for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + _update_chunk_dense( - &X[chunk_idx * n_samples_chunk, 0], - &sample_weight[chunk_idx * 
n_samples_chunk], - &x_squared_norms[chunk_idx * n_samples_chunk], - ¢ers_old[0, 0], - centers_new_chunk, - ¢ers_squared_norms[0], - weight_in_clusters_chunk, - pairwise_distances_chunk, - &labels[chunk_idx * n_samples_chunk], - n_samples_chunk_eff, - n_clusters, - n_features, + &X[start, 0], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] - - free(weight_in_clusters_chunk) - free(centers_new_chunk) - free(pairwise_distances_chunk) - if update_centers: - _relocate_empty_clusters_dense(X, sample_weight, centers_new, - weight_in_clusters, labels) + _relocate_empty_clusters_dense( + X, sample_weight, centers_new, weight_in_clusters, labels) - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) cdef void _update_chunk_dense(floating *X, - floating *sample_weight, - floating *x_squared_norms, - floating *centers_old, - floating *centers_new, - floating *centers_squared_norms, - floating *weight_in_clusters, - floating *pairwise_distances, - int *labels, - int n_samples, - int n_clusters, - int n_features, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, bint update_centers) nogil: """K-means combined EM step for one data chunk - + Compute the partial contribution of a single data chunk to the labels and centers. 
""" cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating sq_dist, min_sq_dist - int i, j, k, best_cluster + int i, j, k, label + + floating *pairwise_distances_ptr = malloc(n_samples * n_clusters * sizeof(floating)) + floating[:, ::1] pairwise_distances + + with gil: + pairwise_distances = pairwise_distances_ptr # Instead of computing the full pairwise squared distances matrix, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store @@ -193,27 +179,31 @@ cdef void _update_chunk_dense(floating *X, # depends on the centers. for i in range(n_samples): for j in range(n_clusters): - pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] - + pairwise_distances[i, j] = centers_squared_norms[j] + _gemm(RowMajor, NoTrans, Trans, n_samples, n_clusters, n_features, - -2.0, X, n_features, centers_old, n_features, - 1.0, pairwise_distances, n_clusters) + -2.0, X, n_features, ¢ers_old[0, 0], n_features, + 1.0, pairwise_distances_ptr, n_clusters) for i in range(n_samples): - min_sq_dist = pairwise_distances[i * n_clusters] - best_cluster = 0 - for j in range(n_clusters): - sq_dist = pairwise_distances[i * n_clusters + j] + min_sq_dist = pairwise_distances[i, 0] + label = 0 + for j in range(1, n_clusters): + sq_dist = pairwise_distances[i, j] if sq_dist < min_sq_dist: min_sq_dist = sq_dist - best_cluster = j + label = j + labels[i] = label - labels[i] = best_cluster + free(pairwise_distances_ptr) - if update_centers: - weight_in_clusters[best_cluster] += sample_weight[i] - for k in range(n_features): - centers_new[best_cluster * n_features + k] += X[i * n_features + k] * sample_weight[i] + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(n_features): + centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] cpdef void _lloyd_iter_chunked_sparse(X, @@ -222,12 +212,12 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[:, ::1] centers_old, floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, + floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs = -1, - bint update_centers = True): - """Single interation of K-means lloyd algorithm + int n_jobs=-1, + bint update_centers=True): + """Single iteration of K-means lloyd algorithm Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -242,7 +232,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, x_squared_norms : {float32, float64} array-like, shape (n_samples,) Squared L2 norm of X. - + centers_old : {float32, float64} array-like, shape (n_clusters, n_features) Centers before previous iteration, placeholder for the centers after previous iteration. @@ -250,7 +240,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, centers_new : {float32, float64} array-like, shape (n_clusters, n_features) Centers after previous iteration, placeholder for the new centers computed during this iteration. - + centers_squared_norms : {float32, float64} array-like, shape (n_clusters,) Squared L2 norm of the centers. @@ -260,7 +250,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, labels : int array-like, shape (n_samples,) labels assignment. - + center_shift : {float32, float64} array-like, shape (n_clusters,) Distance between old and new centers. 
@@ -283,7 +273,8 @@ cpdef void _lloyd_iter_chunked_sparse(X, int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk int n_samples_r = n_samples % n_samples_chunk - int chunk_idx, n_samples_chunk_eff + int chunk_idx, n_samples_chunk_eff = 0 + int start = 0, end = 0 int num_threads int j, k @@ -293,17 +284,13 @@ cpdef void _lloyd_iter_chunked_sparse(X, int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - floating *centers_new_chunk - floating *weight_in_clusters_chunk + # If n_samples < 256 there's still one chunk of size n_samples_r + if n_chunks == 0: + n_chunks = 1 + n_samples_chunk = 0 - # count remainder for total number of chunks - n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - memset(&centers_squared_norms[0], 0, n_clusters * sizeof(floating)) - for j in range(n_clusters): - for k in range(n_features): - centers_squared_norms[j] += centers_new[j, k] * centers_new[j, k] + centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) @@ -312,76 +299,64 @@ cpdef void _lloyd_iter_chunked_sparse(X, # set number of threads to be used by openmp num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + with nogil, parallel(num_threads=num_threads): - # thread local buffers - centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) - weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) for chunk_idx in prange(n_chunks): - if n_samples_r > 0 and chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_r + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r else: n_samples_chunk_eff = n_samples_chunk + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + _update_chunk_sparse( - &X_data[X_indptr[chunk_idx * n_samples_chunk]], -
&X_indices[X_indptr[chunk_idx * n_samples_chunk]], - &X_indptr[chunk_idx * n_samples_chunk], - &sample_weight[chunk_idx * n_samples_chunk], - &x_squared_norms[chunk_idx * n_samples_chunk], - &centers_old[0, 0], - centers_new_chunk, - &centers_squared_norms[0], - weight_in_clusters_chunk, - &labels[chunk_idx * n_samples_chunk], - n_samples_chunk_eff, - n_clusters, - n_features, + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], update_centers) - # reduction from local buffers. The gil is necessary for that to avoid - # race conditions. - if update_centers: - with gil: - for j in range(n_clusters): - weight_in_clusters[j] += weight_in_clusters_chunk[j] - for k in range(n_features): - centers_new[j, k] += centers_new_chunk[j * n_features + k] - - free(weight_in_clusters_chunk) - free(centers_new_chunk) - if update_centers: - _relocate_empty_clusters_sparse(X_data, X_indices, X_indptr, - sample_weight, centers_new, - weight_in_clusters, labels) - - _mean_and_center_shift(centers_old, centers_new, weight_in_clusters, - center_shift) - - -cdef void _update_chunk_sparse(floating *X_data, - int *X_indices, - int *X_indptr, - floating *sample_weight, - floating *x_squared_norms, - floating *centers_old, - floating *centers_new, - floating *centers_squared_norms, - floating *weight_in_cluster, - int *labels, - int n_samples, - int n_clusters, - int n_features, + _relocate_empty_clusters_sparse( + X_data, X_indices, X_indptr, sample_weight, + centers_new, weight_in_clusters, labels) + + _mean_and_center_shift( + centers_old, centers_new, weight_in_clusters, center_shift) + + +cdef void _update_chunk_sparse(floating[::1] X_data, + int[::1] X_indices, + int[::1] X_indptr, + floating[::1] sample_weight, + floating[::1] x_squared_norms, + floating[:, ::1] centers_old, +
floating[:, ::1] centers_new, + floating[::1] centers_squared_norms, + floating[::1] weight_in_clusters, + int[::1] labels, bint update_centers) nogil: """K-means combined EM step for one data chunk - + Compute the partial contribution of a single data chunk to the labels and centers. """ - cdef: + cdef: + int n_samples = labels.shape[0] + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + floating sq_dist, min_sq_dist - int i, j, k, best_cluster + int i, j, k, label floating max_floating = FLT_MAX if floating is float else DBL_MAX int s = X_indptr[0] @@ -390,13 +365,13 @@ cdef void _update_chunk_sparse(floating *X_data, # multiplication is available. for i in range(n_samples): min_sq_dist = max_floating - best_cluster = 0 + label = 0 for j in range(n_clusters): sq_dist = 0.0 for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - sq_dist += centers_old[j * n_features + X_indices[k]] * X_data[k] - + sq_dist += centers_old[j, X_indices[k]] * X_data[k] + # Instead of computing the full squared distance with each cluster, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to compute # the - 2 X.C^T + ||C||² term since the argmin for a given sample @@ -404,11 +379,14 @@ cdef void _update_chunk_sparse(floating *X_data, sq_dist = centers_squared_norms[j] -2 * sq_dist if sq_dist < min_sq_dist: min_sq_dist = sq_dist - best_cluster = j - - labels[i] = best_cluster - - if update_centers: - weight_in_cluster[best_cluster] += sample_weight[i] - for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[best_cluster * n_features + X_indices[k]] += X_data[k] * sample_weight[i] + label = j + + labels[i] = label + + if update_centers: + # The gil is necessary for that to avoid race conditions. 
+ with gil: + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 00e42e1adf27d..657a444fd268a 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -30,13 +30,15 @@ from ..utils._clibs import thread_limits_context from ..utils._joblib import effective_n_jobs from ..exceptions import ConvergenceWarning -from ._k_means import (_inertia_dense, - _inertia_sparse, - _mini_batch_update_csr) -from ._k_means_lloyd import (_lloyd_iter_chunked_dense, - _lloyd_iter_chunked_sparse) -from ._k_means_elkan import (_init_bounds, - _elkan_iter_chunked_dense) +from ._k_means import _inertia_dense +from ._k_means import _inertia_sparse +from ._k_means import _mini_batch_update_csr +from ._k_means_lloyd import _lloyd_iter_chunked_dense +from ._k_means_lloyd import _lloyd_iter_chunked_sparse +from ._k_means_elkan import _init_bounds_dense +from ._k_means_elkan import _init_bounds_sparse +from ._k_means_elkan import _elkan_iter_chunked_dense +from ._k_means_elkan import _elkan_iter_chunked_sparse ############################################################################### @@ -348,13 +350,14 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', best_labels, best_inertia, best_centers = None, None, None - if algorithm == "auto": - algorithm = "full" if sp.issparse(X) else "elkan" if algorithm == "elkan" and n_clusters == 1: warnings.warn("algorithm='elkan' doesn't make sense for a single " "cluster. 
Using 'full' instead.", RuntimeWarning) algorithm = "full" + if algorithm == "auto": + algorithm = "elkan" + if algorithm == "full": kmeans_single = _kmeans_single_lloyd elif algorithm == "elkan": @@ -403,8 +406,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, n_jobs=None): - if sp.issparse(X): - raise TypeError("algorithm='elkan' not supported for sparse input X") + # if sp.issparse(X): + # raise TypeError("algorithm='elkan' not supported for sparse input X") random_state = check_random_state(random_state) sample_weight = _check_sample_weight(X, sample_weight) @@ -422,29 +425,37 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) labels = np.full(n_samples, -1, dtype=np.int32) center_half_distances = euclidean_distances(centers) / 2 - distance_next_center = np.zeros(n_clusters, dtype=X.dtype) + distance_next_center = np.partition(np.asarray(center_half_distances), + kth=1, axis=0)[1] upper_bounds = np.zeros(n_samples, dtype=X.dtype) lower_bounds = np.zeros((n_samples, n_clusters), dtype=X.dtype) center_shift = np.zeros(n_clusters, dtype=X.dtype) - _init_bounds(X, centers, center_half_distances, - labels, upper_bounds, lower_bounds) + if sp.issparse(X): + init_bounds = _init_bounds_sparse + elkan_iter = _elkan_iter_chunked_sparse + _inertia = _inertia_sparse + else: + init_bounds = _init_bounds_dense + elkan_iter = _elkan_iter_chunked_dense + _inertia = _inertia_dense + + init_bounds(X, centers, center_half_distances, + labels, upper_bounds, lower_bounds) for i in range(max_iter): - # compute the closest other center of each center - distance_next_center = np.partition(np.asarray(center_half_distances), - kth=1, axis=0)[1] - - _elkan_iter_chunked_dense(X, sample_weight, centers_old, centers, - weight_in_clusters, 
center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs) + elkan_iter(X, sample_weight, centers_old, centers, weight_in_clusters, + center_half_distances, distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs) - # compute new pairwise distances between centers for next iterations + # compute new pairwise distances between centers and closest other + # center of each center for next iterations center_half_distances = euclidean_distances(centers) / 2 + distance_next_center = np.partition(np.asarray(center_half_distances), + kth=1, axis=0)[1] if verbose: - inertia = _inertia_dense(X, sample_weight, centers_old, labels) + inertia = _inertia(X, sample_weight, centers_old, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) center_shift_tot = (center_shift**2).sum() @@ -456,13 +467,12 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, break # rerun E-step so that predicted labels match cluster centers - _elkan_iter_chunked_dense(X, sample_weight, centers, centers, - weight_in_clusters, center_half_distances, - distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs, - update_centers=False) + elkan_iter(X, sample_weight, centers, centers, weight_in_clusters, + center_half_distances, distance_next_center, upper_bounds, + lower_bounds, labels, center_shift, n_jobs, + update_centers=False) - inertia = _inertia_dense(X, sample_weight, centers, labels) + inertia = _inertia(X, sample_weight, centers, labels) return labels, inertia, centers, i + 1 diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 9e3d1271d3c70..289540f8ca93d 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -46,10 +46,8 @@ X_csr = sp.csr_matrix(X) -@pytest.mark.parametrize("representation, algo", - [('dense', 'full'), - ('dense', 'elkan'), - ('sparse', 'full')]) 
+@pytest.mark.parametrize("representation", ['dense', 'sparse']) +@pytest.mark.parametrize("algo", ['full', 'elkan']) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_kmeans_results(representation, algo, dtype): # cheks that kmeans works as intended @@ -92,6 +90,29 @@ def test_elkan_results(distribution): assert_array_equal(km_elkan.labels_, km_full.labels_) +@pytest.mark.parametrize('distribution', ['normal', 'blobs']) +def test_elkan_results_sparse(distribution): + # check that results are identical between lloyd and elkan algorithms + # with sparse input + rnd = np.random.RandomState(0) + if distribution is 'normal': + X = sp.random(100, 100, density=0.1, format='csr', random_state=rnd) + X.data = rnd.randn(len(X.data)) + else: + X, _ = make_blobs(n_samples=100, n_features=100, random_state=rnd) + X = sp.csr_matrix(X) + + km_full = KMeans(algorithm='full', n_clusters=5, random_state=0, n_init=1) + km_elkan = KMeans(algorithm='elkan', n_clusters=5, + random_state=0, n_init=1) + + km_full.fit(X) + km_elkan.fit(X) + assert_array_almost_equal(km_elkan.cluster_centers_, + km_full.cluster_centers_) + assert_array_equal(km_elkan.labels_, km_full.labels_) + + def test_labels_assignment_and_inertia(): # pure numpy implementation as easily auditable reference gold # implementation @@ -311,20 +332,17 @@ def test_k_means_fit_predict(algo, dtype, constructor, seed, max_iter, tol): # There's a very small chance of failure with elkan on unstructured dataset # because predict method uses fast euclidean distances computation which # may cause small numerical instabilities. 
- if not (algo == 'elkan' and constructor is sp.csr_matrix): - rng = np.random.RandomState(seed) + X = make_blobs(n_samples=1000, n_features=10, centers=10, + random_state=seed)[0].astype(dtype, copy=False) + X = constructor(X) - X = make_blobs(n_samples=1000, n_features=10, centers=10, - random_state=rng)[0].astype(dtype, copy=False) - X = constructor(X) + kmeans = KMeans(algorithm=algo, n_clusters=10, random_state=seed, + tol=tol, max_iter=max_iter, n_jobs=1) - kmeans = KMeans(algorithm=algo, n_clusters=10, random_state=seed, - tol=tol, max_iter=max_iter, n_jobs=1) + labels_1 = kmeans.fit(X).predict(X) + labels_2 = kmeans.fit_predict(X) - labels_1 = kmeans.fit(X).predict(X) - labels_2 = kmeans.fit_predict(X) - - assert_array_equal(labels_1, labels_2) + assert_array_equal(labels_1, labels_2) def test_mb_kmeans_verbose(): @@ -695,11 +713,6 @@ def test_k_means_function(): assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, sample_weight=None) - # kmeans for algorithm='elkan' raises TypeError on sparse matrix - assert_raise_message(TypeError, "algorithm='elkan' not supported for " - "sparse input X", k_means, X=X_csr, n_clusters=2, - sample_weight=None, algorithm="elkan") - def test_x_squared_norms_init_centroids(): # Test that x_squared_norms can be None in _init_centroids From 9ed44364919019de0d5778863596576de3824afa Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 21 Feb 2019 17:36:48 +0100 Subject: [PATCH 058/163] docstrings --- sklearn/cluster/_k_means_elkan.pyx | 61 ++++++++++++++++++++++++------ sklearn/cluster/_k_means_lloyd.pyx | 8 ++-- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 4318a82842c88..583e54ebcb42f 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -32,7 +32,7 @@ cpdef _init_bounds_dense(np.ndarray[floating, ndim=2, mode='c'] X, int[::1] labels, floating[::1] upper_bounds, 
floating[:, ::1] lower_bounds): - """Initialize upper and lower bounds for each sample. + """Initialize upper and lower bounds for each sample for dense input data. Given X, centers and the pairwise distances divided by 2.0 between the centers this calculates the upper bounds and lower bounds for each sample. @@ -49,23 +49,24 @@ cpdef _init_bounds_dense(np.ndarray[floating, ndim=2, mode='c'] X, Parameters ---------- - X : nd-array, shape (n_samples, n_features) + X : {float32, float64} ndarray, shape (n_samples, n_features) The input data. - centers : nd-array, shape (n_clusters, n_features) + centers : {float32, float64} ndarray, shape (n_clusters, n_features) The cluster centers. - center_half_distances : nd-array, shape (n_clusters, n_clusters) + center_half_distances : {float32, float64} ndarray, / +shape (n_clusters, n_clusters) The half of the distance between any 2 clusters centers. - labels : nd-array, shape(n_samples) + labels : int ndarray, shape(n_samples) The label for each sample. This array is modified in place. - lower_bounds : nd-array, shape(n_samples, n_clusters) + lower_bounds : {float32, float64} ndarray, shape(n_samples, n_clusters) The lower bound on the distance between a sample and each cluster center. It is modified in place. - upper_bounds : nd-array, shape(n_samples,) + upper_bounds : {float32, float64} ndarray, shape(n_samples,) The distance of each sample from its closest cluster center. This is modified in place by the function. """ @@ -100,6 +101,44 @@ cpdef _init_bounds_sparse(X, int[::1] labels, floating[::1] upper_bounds, floating[:, ::1] lower_bounds): + """Initialize upper and lower bounds for each sample for sparse input data. + + Given X, centers and the pairwise distances divided by 2.0 between the + centers this calculates the upper bounds and lower bounds for each sample. + The upper bound for each sample is set to the distance between the sample + and the closest center. 
+ + The lower bound for each sample is a one-dimensional array of n_clusters. + For each sample i assume that the previously assigned cluster is c1 and the + previous closest distance is dist, for a new cluster c2, the + lower_bound[i][c2] is set to distance between the sample and this new + cluster, if and only if dist > center_half_distances[c1][c2]. This prevents + computation of unnecessary distances for each sample to the clusters that + it is unlikely to be assigned to. + + Parameters + ---------- + X : csr_matrix, shape (n_samples, n_features) + The input data. + + centers : {float32, float64} ndarray, shape (n_clusters, n_features) + The cluster centers. + + center_half_distances : {float32, float64} ndarray, / +shape (n_clusters, n_clusters) + The half of the distance between any 2 clusters centers. + + labels : int ndarray, shape(n_samples) + The label for each sample. This array is modified in place. + + lower_bounds : {float32, float64} ndarray, shape(n_samples, n_clusters) + The lower bound on the distance between a sample and each cluster + center. It is modified in place. + + upper_bounds : {float32, float64} ndarray, shape(n_samples,) + The distance of each sample from its closest cluster center. This is + modified in place by the function. + """ cdef: int n_samples = X.shape[0] int n_clusters = centers.shape[0] @@ -149,7 +188,7 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] center_shift, int n_jobs=-1, bint update_centers=True): - """Single iteration of K-means elkan algorithm + """Single iteration of K-means elkan algorithm with dense input. Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -288,7 +327,7 @@ cdef void _update_chunk_dense(floating *X, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, bint update_centers) nogil: - """K-means combined EM step for one data chunk + """K-means combined EM step for one dense data chunk. 
Compute the partial contribution of a single data chunk to the labels and centers. @@ -365,7 +404,7 @@ cpdef void _elkan_iter_chunked_sparse(X, floating[::1] center_shift, int n_jobs=-1, bint update_centers=True): - """Single iteration of K-means elkan algorithm with sparse input + """Single iteration of K-means elkan algorithm with sparse input. Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -517,7 +556,7 @@ cdef void _update_chunk_sparse(floating[::1] X_data, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, bint update_centers) nogil: - """K-means combined EM step for one data chunk + """K-means combined EM step for one sparse data chunk. Compute the partial contribution of a single data chunk to the labels and centers. diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index e9e44ded79a5c..d942dacbd0687 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -35,7 +35,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] center_shift, int n_jobs=-1, bint update_centers=True): - """Single iteration of K-means lloyd algorithm + """Single iteration of K-means lloyd algorithm with dense input. Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -154,7 +154,7 @@ cdef void _update_chunk_dense(floating *X, floating[::1] weight_in_clusters, int[::1] labels, bint update_centers) nogil: - """K-means combined EM step for one data chunk + """K-means combined EM step for one dense data chunk. Compute the partial contribution of a single data chunk to the labels and centers. @@ -217,7 +217,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[::1] center_shift, int n_jobs=-1, bint update_centers=True): - """Single iteration of K-means lloyd algorithm + """Single iteration of K-means lloyd algorithm with sparse input. 
Update labels and centers (inplace), for one iteration, distributed over data chunks. @@ -345,7 +345,7 @@ cdef void _update_chunk_sparse(floating[::1] X_data, floating[::1] weight_in_clusters, int[::1] labels, bint update_centers) nogil: - """K-means combined EM step for one data chunk + """K-means combined EM step for one sparse data chunk. Compute the partial contribution of a single data chunk to the labels and centers. From 4d93fa5fb38cc2ceb4877a87635ab2f64d279141 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 21 Feb 2019 17:57:53 +0100 Subject: [PATCH 059/163] nitpick --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 10d7ed5eb761d..82338aebf0b26 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -88,7 +88,7 @@ artifacts: on_success: - "cp ../empty_folder/.coverage ." - - "codecov" + - codecov # Upload the generated wheel package to Rackspace - "python -m wheelhouse_uploader upload --local-folder=dist sklearn-windows-wheels" From 31a3052a4017686ef6fff05e6c1624e0f36ed88a Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 22 Feb 2019 11:23:38 +0100 Subject: [PATCH 060/163] fix euclean_sparse_dense --- sklearn/cluster/_k_means.pyx | 23 ++++++++++++----------- sklearn/cluster/_k_means_elkan.pyx | 1 + 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 600bda9256780..0fda5202a22a3 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -65,20 +65,21 @@ cdef floating _euclidean_sparse_dense(floating[::1] a_data, for i in range(nnz): tmp = a_data[i] - b[a_indices[i]] result += tmp * tmp - b[a_indices[i]] * b[a_indices[i]] - + result += b_squared_norm + if result < 0: result = 0.0 if not squared: result = sqrt(result) - + return result cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, - floating[::1] sample_weight, + floating[::1] sample_weight, 
floating[:, ::1] centers, int[::1] labels): """Compute inertia for dense input data - + Sum of squared distance between each sample and its assigned center. """ cdef: @@ -103,7 +104,7 @@ cpdef floating _inertia_sparse(X, floating[:, ::1] centers, int[::1] labels): """Compute inertia for sparse input data - + Sum of squared distance between each sample and its assigned center. """ cdef: @@ -117,7 +118,7 @@ cpdef floating _inertia_sparse(X, floating sq_dist = 0.0 floating inertia = 0.0 - + floating[::1] center_squared_norms = row_norms(centers, squared=True) for i in range(n_samples): @@ -192,19 +193,19 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int i, j, k floating[::1] distances = np.zeros(n_samples, dtype=X_data.base.dtype) - + for i in range(n_samples): j = labels[i] for k in range(X_indptr[i], X_indptr[i + 1]): x = (X_data[k] - centers[j, X_indices[k]]) distances[i] += x * x - cdef: + cdef: int[::1] far_from_centers = np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx floating weight - + for idx in range(n_empty): new_cluster_id = empty_clusters[idx] @@ -213,7 +214,7 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, weight = sample_weight[far_idx] old_cluster_id = labels[far_idx] - + for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): centers[new_cluster_id, X_indices[k]] += X_data[k] * weight centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight @@ -240,7 +241,7 @@ cdef void _mean_and_center_shift(floating[:, ::1] centers_old, alpha = 1.0 / weight_in_clusters[j] for k in range(n_features): centers_new[j, k] *= alpha - + # compute shift distance between old and new centers for j in range(n_clusters): tmp = 0 diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 4318a82842c88..20cc6e0d7bd22 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -120,6 +120,7 @@ cpdef 
_init_bounds_sparse(X, X_data[X_indptr[i]: X_indptr[i + 1]], X_indices[X_indptr[i]: X_indptr[i + 1]], centers[0], centers_squared_norms[0], False) + print(min_dist) lower_bounds[i, 0] = min_dist for j in range(1, n_clusters): From dda6527a3bbb540e57ac50b86974edfdf3516654 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 14:25:49 +0100 Subject: [PATCH 061/163] fix relocate empty cluster --- sklearn/cluster/_k_means.pxd | 17 ++-- sklearn/cluster/_k_means.pyx | 74 ++++++++------- sklearn/cluster/_k_means_elkan.pyx | 130 ++++++++++++-------------- sklearn/cluster/_k_means_lloyd.pyx | 120 ++++++++++++------------ sklearn/cluster/k_means_.py | 25 ++--- sklearn/cluster/tests/test_k_means.py | 25 +++++ 6 files changed, 207 insertions(+), 184 deletions(-) diff --git a/sklearn/cluster/_k_means.pxd b/sklearn/cluster/_k_means.pxd index a005250ad37e2..385e9cbbb2ef2 100644 --- a/sklearn/cluster/_k_means.pxd +++ b/sklearn/cluster/_k_means.pxd @@ -10,13 +10,14 @@ cdef floating _euclidean_dense_dense(floating*, floating*, int, bint) nogil cdef floating _euclidean_sparse_dense(floating[::1], int[::1], floating[::1], floating, bint) nogil -cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'], - floating[::1], floating[:, ::1], - floating[::1], int[::1]) +cdef void _relocate_empty_clusters_dense( + np.ndarray[floating, ndim=2, mode='c'], floating[::1], floating[:, ::1], + floating[:, ::1], floating[::1], int[::1]) -cdef void _relocate_empty_clusters_sparse(floating[::1], int[::1], int[::1], - floating[::1], floating[:, ::1], - floating[::1], int[::1]) +cdef void _relocate_empty_clusters_sparse( + floating[::1], int[::1], int[::1], floating[::1], floating[:, ::1], + floating[:, ::1], floating[::1], int[::1]) -cdef void _mean_and_center_shift(floating[:, ::1], floating[:, ::1], - floating[::1], floating[::1]) +cdef void _average_centers(floating[:, ::1], floating[::1]) + +cdef void _center_shift(floating[:, ::1], floating[:, ::1], 
floating[::1]) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 18c93efeb198d..46a6f45f11573 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -119,14 +119,14 @@ cpdef floating _inertia_sparse(X, floating sq_dist = 0.0 floating inertia = 0.0 - floating[::1] center_squared_norms = row_norms(centers, squared=True) + floating[::1] centers_squared_norms = row_norms(centers, squared=True) for i in range(n_samples): j = labels[i] sq_dist = _euclidean_sparse_dense( X_data[X_indptr[i]: X_indptr[i + 1]], X_indices[X_indptr[i]: X_indptr[i + 1]], - centers[j], center_squared_norms[j], True) + centers[j], centers_squared_norms[j], True) inertia += sq_dist * sample_weight[i] return inertia @@ -134,7 +134,8 @@ cpdef floating _inertia_sparse(X, cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] sample_weight, - floating[:, ::1] centers, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, floating[::1] weight_in_clusters, int[::1] labels): """Relocate centers which have no sample assigned to them.""" @@ -148,13 +149,12 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] cdef: int n_features = X.shape[1] - floating[::1] distances = ((np.asarray(X) - np.asarray(centers)[labels])**2).sum(axis=1) - - int[::1] far_from_centers = np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) + floating[::1] distances = ((np.asarray(X) - np.asarray(centers_old)[labels])**2).sum(axis=1) + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:n_empty-1:-1].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx, k floating weight - + print() for idx in range(n_empty): new_cluster_id = empty_clusters[idx] @@ -165,18 +165,19 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] old_cluster_id = labels[far_idx] for k in range(n_features): - centers[new_cluster_id, k] = X[far_idx, k] * weight - 
centers[old_cluster_id, k] -= X[far_idx, k] * weight + centers_new[old_cluster_id, k] -= X[far_idx, k] * weight + centers_new[new_cluster_id, k] = X[far_idx, k] * weight weight_in_clusters[new_cluster_id] = weight weight_in_clusters[old_cluster_id] -= weight - + print('ok') cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int[::1] X_indices, int[::1] X_indptr, floating[::1] sample_weight, - floating[:, ::1] centers, + floating[:, ::1] centers_old, + floating[:, ::1] centers_new, floating[::1] weight_in_clusters, int[::1] labels): """Relocate centers which have no sample assigned to them.""" @@ -189,19 +190,22 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, cdef: int n_samples = X_indptr.shape[0] - 1 + int n_features = centers_old.shape[1] floating x int i, j, k floating[::1] distances = np.zeros(n_samples, dtype=X_data.base.dtype) + floating[::1] centers_squared_norms = row_norms(centers_old, squared=True) for i in range(n_samples): j = labels[i] - for k in range(X_indptr[i], X_indptr[i + 1]): - x = (X_data[k] - centers[j, X_indices[k]]) - distances[i] += x * x + distances[i] = _euclidean_sparse_dense( + X_data[X_indptr[i]: X_indptr[i + 1]], + X_indices[X_indptr[i]: X_indptr[i + 1]], + centers_old[j], centers_squared_norms[j], True) cdef: - int[::1] far_from_centers = np.argpartition(distances, -n_empty)[-n_empty:].astype(np.int32) + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:n_empty-1:-1].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx floating weight @@ -216,39 +220,41 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, old_cluster_id = labels[far_idx] for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): - centers[new_cluster_id, X_indices[k]] += X_data[k] * weight - centers[old_cluster_id, X_indices[k]] -= X_data[k] * weight + centers_new[old_cluster_id, X_indices[k]] -= X_data[k] * weight + centers_new[new_cluster_id, X_indices[k]] = X_data[k] * weight 
weight_in_clusters[new_cluster_id] = weight weight_in_clusters[old_cluster_id] -= weight -cdef void _mean_and_center_shift(floating[:, ::1] centers_old, - floating[:, ::1] centers_new, - floating[::1] weight_in_clusters, - floating[::1] center_shift): - """Average new centers wrt weights and compute center shift.""" +cdef void _average_centers(floating[:, ::1] centers, + floating[::1] weight_in_clusters): + """Average new centers wrt weights.""" cdef: - int n_clusters = centers_old.shape[0] - int n_features = centers_old.shape[1] - + int n_clusters = centers.shape[0] + int n_features = centers.shape[1] int j, k - floating alpha, tmp, x + floating alpha - # average new centers wrt sample weights for j in range(n_clusters): if weight_in_clusters[j] > 0: alpha = 1.0 / weight_in_clusters[j] for k in range(n_features): - centers_new[j, k] *= alpha + centers[j, k] *= alpha + + +cdef void _center_shift(floating[:, ::1] centers_old, + floating[:, ::1] centers_new, + floating[::1] center_shift): + """Compute shift between old and new centers.""" + cdef: + int n_clusters = centers_old.shape[0] + int n_features = centers_old.shape[1] + int j - # compute shift distance between old and new centers for j in range(n_clusters): - tmp = 0 - for k in range(n_features): - x = centers_new[j, k] - centers_old[j, k] - tmp += x * x - center_shift[j] = sqrt(tmp) + center_shift[j] = _euclidean_dense_dense( + &centers_new[j, 0], &centers_old[j, 0], n_features, False) def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight, diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 583e54ebcb42f..9f573cfc1d873 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -18,9 +18,10 @@ from libc.string cimport memset, memcpy from ..utils.extmath import row_norms from ._k_means cimport _relocate_empty_clusters_dense from ._k_means cimport _relocate_empty_clusters_sparse -from ._k_means cimport _mean_and_center_shift from
._k_means cimport _euclidean_dense_dense from ._k_means cimport _euclidean_sparse_dense +from ._k_means cimport _average_centers +from ._k_means cimport _center_shift np.import_array() @@ -186,7 +187,7 @@ cpdef void _elkan_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] lower_bounds, int[::1] labels, floating[::1] center_shift, - int n_jobs=-1, + int n_jobs, bint update_centers=True): """Single iteration of K-means elkan algorithm with dense input. @@ -256,7 +257,7 @@ shape (n_clusters, n_clusters) int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - int num_threads + # int num_threads int i, j, k @@ -271,40 +272,35 @@ shape (n_clusters, n_clusters) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - # set number of threads to be used by openmp - num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - - with nogil, parallel(num_threads=num_threads): - - for chunk_idx in prange(n_chunks): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_dense( - &X[start, 0], - sample_weight[start: end], - centers_old, - centers_new, - center_half_distances, - distance_next_center, - weight_in_clusters, - labels[start: end], - upper_bounds[start: end], - lower_bounds[start: end], - update_centers) + for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end],
centers_old, + centers_new, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + update_centers) if update_centers: - _relocate_empty_clusters_dense( - X, sample_weight, centers_new, weight_in_clusters, labels) + _relocate_empty_clusters_dense(X, sample_weight, centers_old, + centers_new, weight_in_clusters, labels) - _mean_and_center_shift( - centers_old, centers_new, weight_in_clusters, center_shift) + _average_centers(centers_new, weight_in_clusters) + _center_shift(centers_old, centers_new, center_shift) # update lower and upper bounds for i in range(n_samples): @@ -402,7 +398,7 @@ cpdef void _elkan_iter_chunked_sparse(X, floating[:, ::1] lower_bounds, int[::1] labels, floating[::1] center_shift, - int n_jobs=-1, + int n_jobs, bint update_centers=True): """Single iteration of K-means elkan algorithm with sparse input. @@ -476,7 +472,7 @@ shape (n_clusters, n_clusters) int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - int num_threads + # int num_threads int i, j, k @@ -494,43 +490,41 @@ shape (n_clusters, n_clusters) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) # set number of threads to be used by openmp - num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - - with nogil, parallel(num_threads=num_threads): - - for chunk_idx in prange(n_chunks): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_sparse( - X_data[X_indptr[start]: X_indptr[end]], - X_indices[X_indptr[start]: X_indptr[end]], - X_indptr[start: end], - sample_weight[start: end], - centers_old, - centers_new, - centers_squared_norms, - center_half_distances, - distance_next_center, - weight_in_clusters, 
- labels[start: end], - upper_bounds[start: end], - lower_bounds[start: end], - update_centers) + # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + centers_old, + centers_new, + centers_squared_norms, + center_half_distances, + distance_next_center, + weight_in_clusters, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + update_centers) if update_centers: _relocate_empty_clusters_sparse( X_data, X_indices, X_indptr, sample_weight, - centers_new, weight_in_clusters, labels) + centers_old, centers_new, weight_in_clusters, labels) - _mean_and_center_shift( - centers_old, centers_new, weight_in_clusters, center_shift) + _average_centers(centers_new, weight_in_clusters) + _center_shift(centers_old, centers_new, center_shift) # update lower and upper bounds for i in range(n_samples): diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index d942dacbd0687..7c226224e4014 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -18,7 +18,7 @@ from ..utils._cython_blas cimport _gemm from ..utils._cython_blas cimport RowMajor, Trans, NoTrans from ._k_means cimport _relocate_empty_clusters_dense from ._k_means cimport _relocate_empty_clusters_sparse -from ._k_means cimport _mean_and_center_shift +from ._k_means cimport _average_centers, _center_shift np.import_array() @@ -33,7 +33,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] 
weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs=-1, + int n_jobs, bint update_centers=True): """Single iteration of K-means lloyd algorithm with dense input. @@ -94,7 +94,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - int num_threads + # int num_threads int j, k @@ -112,37 +112,35 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) # set number of threads to be used by openmp - num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - - with nogil, parallel(num_threads=num_threads): - - for chunk_idx in prange(n_chunks): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_dense( - &X[start, 0], - sample_weight[start: end], - x_squared_norms[start: end], - centers_old, - centers_new, - centers_squared_norms, - weight_in_clusters, - labels[start: end], - update_centers) + # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], + update_centers) if update_centers: - _relocate_empty_clusters_dense( - X, sample_weight, centers_new, weight_in_clusters, 
labels) + _relocate_empty_clusters_dense(X, sample_weight, centers_old, + centers_new, weight_in_clusters, labels) - _mean_and_center_shift( - centers_old, centers_new, weight_in_clusters, center_shift) + _average_centers(centers_new, weight_in_clusters) + _center_shift(centers_old, centers_new, center_shift) cdef void _update_chunk_dense(floating *X, @@ -215,7 +213,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, - int n_jobs=-1, + int n_jobs, bint update_centers=True): """Single iteration of K-means lloyd algorithm with sparse input. @@ -275,7 +273,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, int n_samples_r = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff = 0 int start = 0, end = 0 - int num_threads + # int num_threads int j, k floating alpha @@ -298,40 +296,38 @@ cpdef void _lloyd_iter_chunked_sparse(X, memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) # set number of threads to be used by openmp - num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - - with nogil, parallel(num_threads=num_threads): - - for chunk_idx in prange(n_chunks): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_sparse( - X_data[X_indptr[start]: X_indptr[end]], - X_indices[X_indptr[start]: X_indptr[end]], - X_indptr[start: end], - sample_weight[start: end], - x_squared_norms[start: end], - centers_old, - centers_new, - centers_squared_norms, - weight_in_clusters, - labels[start: end], - update_centers) + # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() + + for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): + # remaining samples added to last chunk + if chunk_idx == n_chunks - 1: + n_samples_chunk_eff = 
n_samples_chunk + n_samples_r + else: + n_samples_chunk_eff = n_samples_chunk + + start = chunk_idx * n_samples_chunk + end = start + n_samples_chunk_eff + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_new, + centers_squared_norms, + weight_in_clusters, + labels[start: end], + update_centers) if update_centers: _relocate_empty_clusters_sparse( X_data, X_indices, X_indptr, sample_weight, - centers_new, weight_in_clusters, labels) + centers_old, centers_new, weight_in_clusters, labels) - _mean_and_center_shift( - centers_old, centers_new, weight_in_clusters, center_shift) + _average_centers(centers_new, weight_in_clusters) + _center_shift(centers_old, centers_new, center_shift) cdef void _update_chunk_sparse(floating[::1] X_data, diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 657a444fd268a..270c2b77a8fd6 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -356,7 +356,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', algorithm = "full" if algorithm == "auto": - algorithm = "elkan" + algorithm = "full" if n_clusters == 1 else "elkan" if algorithm == "full": kmeans_single = _kmeans_single_lloyd @@ -366,7 +366,6 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" " %s" % str(algorithm)) - n_jobs_ = -1 if n_jobs is None else effective_n_jobs(n_jobs) seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) # limit number of threads in second level of nested parallelism (i.e. 
BLAS) @@ -377,7 +376,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', labels, inertia, centers, n_iter_ = kmeans_single( X, sample_weight, n_clusters, max_iter=max_iter, init=init, verbose=verbose, tol=tol, x_squared_norms=x_squared_norms, - random_state=seed, n_jobs=n_jobs_) + random_state=seed, n_jobs=effective_n_jobs(n_jobs)) # determine if these results are the best so far if best_inertia is None or inertia < best_inertia: best_labels = labels.copy() @@ -409,6 +408,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, # if sp.issparse(X): # raise TypeError("algorithm='elkan' not supported for sparse input X") + n_jobs_ = effective_n_jobs(n_jobs) random_state = check_random_state(random_state) sample_weight = _check_sample_weight(X, sample_weight) @@ -446,7 +446,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): elkan_iter(X, sample_weight, centers_old, centers, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs) + lower_bounds, labels, center_shift, n_jobs_) # compute new pairwise distances between centers and closest other # center of each center for next iterations @@ -469,7 +469,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, # rerun E-step so that predicted labels match cluster centers elkan_iter(X, sample_weight, centers, centers, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs, + lower_bounds, labels, center_shift, n_jobs_, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -479,7 +479,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, - random_state=None, tol=1e-4, n_jobs=-1): + random_state=None, tol=1e-4, n_jobs=None): 
"""A single run of k-means, assumes preparation completed prior. Parameters @@ -546,8 +546,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, n_iter : int Number of iterations run. """ + n_jobs_ = effective_n_jobs(n_jobs) random_state = check_random_state(random_state) - sample_weight = _check_sample_weight(X, sample_weight) # init @@ -573,7 +573,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs) + center_shift, n_jobs_) if verbose: inertia = _inertia(X, sample_weight, centers_old, labels) @@ -590,14 +590,14 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, # rerun E-step so that predicted labels match cluster centers lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs, update_centers=False) + center_shift, n_jobs_, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) return labels, inertia, centers, i + 1 -def _labels_inertia(X, sample_weight, x_squared_norms, centers): +def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): """E step of the K-means EM algorithm. Compute the labels and the inertia of the given samples and centers. 
@@ -641,7 +641,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers): _labels(X, sample_weight, x_squared_norms, centers, centers, centers_squared_norms, weight_in_clusters, - labels, center_shift, update_centers=False) + labels, center_shift, n_jobs, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -1045,7 +1045,8 @@ def predict(self, X, sample_weight=None): x_squared_norms = row_norms(X, squared=True) return _labels_inertia(X, sample_weight, x_squared_norms, - self.cluster_centers_)[0] + self.cluster_centers_, + effective_n_jobs(self.n_jobs))[0] def score(self, X, y=None, sample_weight=None): """Opposite of the value of X on the K-means objective. diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index e194f598b0dcc..7c6f37e073434 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -70,6 +70,31 @@ def test_kmeans_results(representation, algo, dtype): assert kmeans.n_iter_ == expected_n_iter +@pytest.mark.parametrize("array_constr", + [np.array, sp.csr_matrix], + ids=['dense', 'sparse']) +@pytest.mark.parametrize("algo", ['full', 'elkan']) +def test_relocated_clusters(array_constr, algo): + # check that empty clusters are relocated as expected + X = array_constr([[0, 0], [0.5, 0], [0.5, 1], [1, 1]]) + + # second center too far from others points will be empty at first iter + init_centers = np.array([[0.5, 0.5], [3, 3]]) + + expected_labels = [0, 0, 1, 1] + expected_inertia = 0.25 + expected_centers = [[0.25, 0], [0.75, 1]] + expected_n_iter = 3 + + kmeans = KMeans(n_clusters=2, n_init=1, init=init_centers, algorithm=algo) + kmeans.fit(X) + + assert_array_equal(kmeans.labels_, expected_labels) + assert_almost_equal(kmeans.inertia_, expected_inertia) + assert_array_almost_equal(kmeans.cluster_centers_, expected_centers) + assert kmeans.n_iter_ == expected_n_iter + + @pytest.mark.parametrize('distribution', ['normal', 'blobs']) def 
test_elkan_results(distribution): # check that results are identical between lloyd and elkan algorithms From a48504a89f32ce2e3a11c236d686283dccd3811f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 16:00:28 +0100 Subject: [PATCH 062/163] fix relocate empty clusters --- sklearn/cluster/_k_means.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 46a6f45f11573..b115024090a5d 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -150,11 +150,11 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] int n_features = X.shape[1] floating[::1] distances = ((np.asarray(X) - np.asarray(centers_old)[labels])**2).sum(axis=1) - int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:n_empty-1:-1].astype(np.int32) + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:-n_empty-1:-1].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx, k floating weight - print() + for idx in range(n_empty): new_cluster_id = empty_clusters[idx] @@ -205,7 +205,7 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, centers_old[j], centers_squared_norms[j], True) cdef: - int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:n_empty-1:-1].astype(np.int32) + int[::1] far_from_centers = np.argpartition(distances, -n_empty)[:-n_empty-1:-1].astype(np.int32) int new_cluster_id, old_cluster_id, far_idx, idx floating weight From eb09a062c733834e4f58aae914b6c968ea89a4e5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 16:17:05 +0100 Subject: [PATCH 063/163] lint... 
--- sklearn/cluster/tests/test_k_means.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 7c6f37e073434..d3c5d5b9390de 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -120,7 +120,7 @@ def test_elkan_results_sparse(distribution): # check that results are identical between lloyd and elkan algorithms # with sparse input rnd = np.random.RandomState(0) - if distribution is 'normal': + if distribution == 'normal': X = sp.random(100, 100, density=0.1, format='csr', random_state=rnd) X.data = rnd.randn(len(X.data)) else: From 014956d905dd21203ba5437fd83422c05a8f18f9 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 17:46:46 +0100 Subject: [PATCH 064/163] tst azure openmp --- build_tools/azure/install.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index b4e04e2d41af6..bcda559505f9e 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -6,19 +6,19 @@ UNAMESTR=`uname` if [[ "$UNAMESTR" == "Darwin" ]]; then # install OpenMP not present by default on osx - HOMEBREW_NO_AUTO_UPDATE=1 brew install libomp + HOMEBREW_NO_AUTO_UPDATE=1 brew install libiomp # enable OpenMP support for Apple-clang export CC=/usr/bin/clang export CXX=/usr/bin/clang++ export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" - export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" - export LDFLAGS="$LDFLAGS -L/usr/local/opt/libomp/lib -lomp" - export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib + export CFLAGS="$CFLAGS -I/usr/local/opt/libiomp/include" + export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libiomp/include" + export LDFLAGS="$LDFLAGS -L/usr/local/opt/libiomp/lib -liomp" + export DYLD_LIBRARY_PATH=/usr/local/opt/libiomp/lib # avoid error due to 
multiple OpenMP libraries loaded simultaneously - export KMP_DUPLICATE_LIB_OK=TRUE + # export KMP_DUPLICATE_LIB_OK=TRUE fi make_conda() { From ec74a7641b3ffa4331d267642b3c36f4f4a8a757 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 18:24:12 +0100 Subject: [PATCH 065/163] tst openmp --- build_tools/azure/install.sh | 10 +++++----- build_tools/travis/install.sh | 2 +- sklearn/__init__.py | 4 ++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index bcda559505f9e..cdee3611b0383 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -6,16 +6,16 @@ UNAMESTR=`uname` if [[ "$UNAMESTR" == "Darwin" ]]; then # install OpenMP not present by default on osx - HOMEBREW_NO_AUTO_UPDATE=1 brew install libiomp + HOMEBREW_NO_AUTO_UPDATE=1 brew install libomp # enable OpenMP support for Apple-clang export CC=/usr/bin/clang export CXX=/usr/bin/clang++ export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" - export CFLAGS="$CFLAGS -I/usr/local/opt/libiomp/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libiomp/include" - export LDFLAGS="$LDFLAGS -L/usr/local/opt/libiomp/lib -liomp" - export DYLD_LIBRARY_PATH=/usr/local/opt/libiomp/lib + export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" + export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" + export LDFLAGS="$LDFLAGS -L/usr/local/opt/libomp/lib -lomp" + export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib # avoid error due to multiple OpenMP libraries loaded simultaneously # export KMP_DUPLICATE_LIB_OK=TRUE diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index d88af3ed81d4f..804f761cabc71 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -40,7 +40,7 @@ then export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib # avoid error due to multiple OpenMP libraries loaded simultaneously - export KMP_DUPLICATE_LIB_OK=TRUE + # export 
KMP_DUPLICATE_LIB_OK=TRUE fi make_conda() { diff --git a/sklearn/__init__.py b/sklearn/__init__.py index aafc8a34b2a13..233aa16c52141 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -16,6 +16,7 @@ import re import warnings import logging +import os from ._config import get_config, set_config, config_context @@ -47,6 +48,9 @@ __version__ = '0.21.dev0' +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", True) + + try: # This variable is injected in the __builtins__ by the build # process. It is used to enable importing subpackages of sklearn when From 5485c96c785f06afc730b68cb31bc62399aedb95 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 26 Feb 2019 18:31:28 +0100 Subject: [PATCH 066/163] same --- sklearn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 233aa16c52141..89c7f0e8614ce 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -48,7 +48,7 @@ __version__ = '0.21.dev0' -os.environ.setdefault("KMP_DUPLICATE_LIB_OK", True) +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True") try: From 8a07a32b57476c993aa303ae8968e7c5d9c261cb Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 28 Feb 2019 16:08:34 +0100 Subject: [PATCH 067/163] adress comments & improve docstrings --- sklearn/cluster/_k_means.pyx | 15 +- sklearn/cluster/_k_means_elkan.pyx | 17 +- sklearn/cluster/_k_means_lloyd.pyx | 20 +-- sklearn/cluster/k_means_.py | 270 +++++++++++++++++------------ 4 files changed, 181 insertions(+), 141 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index b115024090a5d..3459942916abb 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -35,6 +35,7 @@ cdef floating _euclidean_dense_dense(floating* a, int rem = n_features % 4 floating result = 0 + # We manually unroll the loop for better cache optimization. 
for i in range(n): result += ((a[0] - b[0]) * (a[0] - b[0]) +(a[1] - b[1]) * (a[1] - b[1]) @@ -45,9 +46,7 @@ cdef floating _euclidean_dense_dense(floating* a, for i in range(rem): result += (a[i] - b[i]) * (a[i] - b[i]) - if not squared: result = sqrt(result) - - return result + return result if squared else sqrt(result) cdef floating _euclidean_sparse_dense(floating[::1] a_data, @@ -69,9 +68,8 @@ cdef floating _euclidean_sparse_dense(floating[::1] a_data, result += b_squared_norm if result < 0: result = 0.0 - if not squared: result = sqrt(result) - return result + return result is squared else sqrt(result) cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, @@ -79,7 +77,7 @@ cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[:, ::1] centers, int[::1] labels): """Compute inertia for dense input data - + Sum of squared distance between each sample and its assigned center. """ cdef: @@ -118,7 +116,7 @@ cpdef floating _inertia_sparse(X, floating sq_dist = 0.0 floating inertia = 0.0 - + floating[::1] centers_squared_norms = row_norms(centers, squared=True) for i in range(n_samples): @@ -170,7 +168,6 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] weight_in_clusters[new_cluster_id] = weight weight_in_clusters[old_cluster_id] -= weight - print('ok') cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int[::1] X_indices, @@ -218,7 +215,7 @@ cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, weight = sample_weight[far_idx] old_cluster_id = labels[far_idx] - + for k in range(X_indptr[far_idx], X_indptr[far_idx + 1]): centers_new[old_cluster_id, X_indices[k]] -= X_data[k] * weight centers_new[new_cluster_id, X_indices[k]] = X_data[k] * weight diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 9f573cfc1d873..e3dca5f8fd2ea 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -254,14 +254,13 
@@ shape (n_clusters, n_clusters) # necessary to get parallelism. Chunk size chosed to be same as lloyd's int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk - int n_samples_r = n_samples % n_samples_chunk + int n_samples_rem = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - # int num_threads int i, j, k - # If n_samples < 256 there's still one chunk of size n_samples_r + # If n_samples < 256 there's still one chunk of size n_samples_rem if n_chunks == 0: n_chunks = 1 n_samples_chunk = 0 @@ -275,7 +274,7 @@ shape (n_clusters, n_clusters) for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): # remaining samples added to last chunk if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r + n_samples_chunk_eff = n_samples_chunk + n_samples_rem else: n_samples_chunk_eff = n_samples_chunk @@ -469,16 +468,15 @@ shape (n_clusters, n_clusters) # necessary to get parallelism. 
Chunk size chosed to be same as lloyd's int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk - int n_samples_r = n_samples % n_samples_chunk + int n_samples_rem = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - # int num_threads int i, j, k floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) - # If n_samples < 256 there's still one chunk of size n_samples_r + # If n_samples < 256 there's still one chunk of size n_samples_rem if n_chunks == 0: n_chunks = 1 n_samples_chunk = 0 @@ -489,13 +487,10 @@ shape (n_clusters, n_clusters) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - # set number of threads to be used by openmp - # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): # remaining samples added to last chunk if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r + n_samples_chunk_eff = n_samples_chunk + n_samples_rem else: n_samples_chunk_eff = n_samples_chunk diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 7c226224e4014..79278823b6b06 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -91,14 +91,13 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # optimal in all situations.
int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk - int n_samples_r = n_samples % n_samples_chunk + int n_samples_rem = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff int start, end - # int num_threads int j, k - # If n_samples < 256 there's still one chunk of size n_samples_r + # If n_samples < 256 there's still one chunk of size n_samples_rem if n_chunks == 0: n_chunks = 1 n_samples_chunk = 0 @@ -111,13 +110,10 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - # set number of threads to be used by openmp - # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): # remaining samples added to last chunk if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r + n_samples_chunk_eff = n_samples_chunk + n_samples_rem else: n_samples_chunk_eff = n_samples_chunk @@ -270,10 +266,9 @@ cpdef void _lloyd_iter_chunked_sparse(X, # However, splitting in chunks is necessary to get parallelism.
int n_samples_chunk = 256 if n_samples > 256 else n_samples int n_chunks = n_samples // n_samples_chunk - int n_samples_r = n_samples % n_samples_chunk + int n_samples_rem = n_samples % n_samples_chunk int chunk_idx, n_samples_chunk_eff = 0 int start = 0, end = 0 - # int num_threads int j, k floating alpha @@ -282,7 +277,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - # If n_samples < 256 there's still one chunk of size n_samples_r + # If n_samples < 256 there's still one chunk of size n_samples_rem if n_chunks == 0: n_chunks = 1 n_samples_chunk = 0 @@ -295,13 +290,10 @@ cpdef void _lloyd_iter_chunked_sparse(X, memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - # set number of threads to be used by openmp - # num_threads = n_jobs if n_jobs != -1 else openmp.omp_get_max_threads() - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): # remaining samples added to last chunk if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_r + n_samples_chunk_eff = n_samples_chunk + n_samples_rem else: n_samples_chunk_eff = n_samples_chunk diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 270c2b77a8fd6..53a7207157831 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -28,7 +28,6 @@ from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES from ..utils._clibs import thread_limits_context -from ..utils._joblib import effective_n_jobs from ..exceptions import ConvergenceWarning from ._k_means import _inertia_dense from ._k_means import _inertia_sparse @@ -65,7 +64,7 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None): randomness deterministic. See :term:`Glossary <random_state>`. 
- n_local_trials : integer, optional + n_local_trials : integer or None (default=None) The number of seeding trials for each center (except the first), of which the one reducing inertia the most is greedily chosen. Set to None to make the number of trials depend logarithmically @@ -205,12 +204,12 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', The number of clusters to form as well as the number of centroids to generate. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional (default=None) The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight - init : {'k-means++', 'random', or ndarray, or a callable}, optional - Method for initialization, default to 'k-means++': + init : {'k-means++', 'random', ndarray, callable}, (default='k-means++') + Method for initialization: 'k-means++' : selects initial cluster centers for k-mean clustering in a smart way to speed up convergence. See section @@ -239,47 +238,45 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', 'precompute_distances' was deprecated in version 0.21 and will be removed in 0.23. - n_init : int, optional, default: 10 + n_init : int, (default=10) Number of time the k-means algorithm will be run with different centroid seeds. The final results will be the best output of n_init consecutive runs in terms of inertia. - max_iter : int, optional, default 300 + max_iter : int, (default=300) Maximum number of iterations of the k-means algorithm to run. - verbose : boolean, optional + verbose : boolean, optional (default=False) Verbosity mode. - tol : float, optional + tol : float (default=1e-4) The relative increment in the results before declaring convergence. 
- random_state : int, RandomState instance or None (default) + random_state : int, RandomState instance or None (default=None) Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. See :term:`Glossary `. - copy_x : boolean, optional + copy_x : boolean, optional (default=True) When pre-computing distances it is more numerically accurate to center - the data first. If copy_x is True (default), then the original data is + the data first. If copy_x is True (default), then the original data is not modified. If False, the original data is modified, and put back before the function returns, but small numerical differences may be introduced by subtracting and then adding the data mean. Note that if the original data is not C-contiguous, a copy will be made even if - copy_x is False. + copy_x is False. If the original data is sparse, but not in CSR format, + a copy will be made even if copy_x is False. n_jobs : int or None, optional (default=None) - The number of jobs to use for the computation. This works by computing - each of the n_init runs in parallel. + The number of jobs to use for the computation. - ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. - ``-1`` means using all processors. See :term:`Glossary ` - for more details. + ``None`` or ``-1`` means using all processors. See + :term:`Glossary ` for more details. - algorithm : "auto", "full" or "elkan", default="auto" + algorithm : {"auto", "full", "elkan"} (default="auto") K-means algorithm to use. The classical EM-style algorithm is "full". - The "elkan" variation is more efficient by using the triangle - inequality, but currently doesn't support sparse data. "auto" chooses - "elkan" for dense data and "full" for sparse data. + The "elkan" variation is more efficient, on well structured data, by + using the triangle inequality. "auto" chooses "elkan". return_n_iter : bool, optional Whether or not to return the number of iterations. 
@@ -366,6 +363,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" " %s" % str(algorithm)) + # seeds for the initializations of the kmeans runs. seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) # limit number of threads in second level of nested parallelism (i.e. BLAS) @@ -376,7 +374,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', labels, inertia, centers, n_iter_ = kmeans_single( X, sample_weight, n_clusters, max_iter=max_iter, init=init, verbose=verbose, tol=tol, x_squared_norms=x_squared_norms, - random_state=seed, n_jobs=effective_n_jobs(n_jobs)) + random_state=seed, n_jobs=n_jobs) # determine if these results are the best so far if best_inertia is None or inertia < best_inertia: best_labels = labels.copy() @@ -405,10 +403,73 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, n_jobs=None): - # if sp.issparse(X): - # raise TypeError("algorithm='elkan' not supported for sparse input X") + """A single run of k-means lloyd, assumes preparation completed prior. + + Parameters + ---------- + X : array-like or CSR matrix, shape (n_samples, n_features) + The observations to cluster. + + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + + n_clusters : int + The number of clusters to form as well as the number of + centroids to generate. + + max_iter : int (default=300) + Maximum number of iterations of the k-means algorithm to run. + + init : {'k-means++', 'random', ndarray, callable} (default='k-means++') + Method for initialization, default to 'k-means++': + + 'k-means++' : selects initial cluster centers for k-mean + clustering in a smart way to speed up convergence. See section + Notes in k_init for more details. 
+ + 'random': choose k observations (rows) at random from data for + the initial centroids. + + If an ndarray is passed, it should be of shape (k, p) and gives + the initial centers. + + If a callable is passed, it should take arguments X, k and + and a random state and return an initialization. + + verbose : boolean, optional (default=False) + Verbosity mode + + x_squared_norms : array-like or None (default=None) + Precomputed x_squared_norms. + + random_state : int, RandomState instance or None (default=None) + Determines random number generation for centroid initialization. Use + an int to make the randomness deterministic. + See :term:`Glossary `. + + tol : float (default=1e-4) + The relative increment in the results before declaring convergence. + + n_jobs : int or None (default=None) + The number of threads to be used. If -1 or None, will use as many as + possible. + + Returns + ------- + centroid : float ndarray, shape (n_clusters, n_features) + Centroids found at the last iteration of k-means. + + label : integer ndarray, shape (n_samples,) + label[i] is the code or index of the centroid the + i'th observation is closest to. - n_jobs_ = effective_n_jobs(n_jobs) + inertia : float + The final value of the inertia criterion (sum of squared distances to + the closest centroid for all observations in the training set). + + n_iter : int + Number of iterations run. 
+ """ random_state = check_random_state(random_state) sample_weight = _check_sample_weight(X, sample_weight) @@ -446,7 +507,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): elkan_iter(X, sample_weight, centers_old, centers, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs_) + lower_bounds, labels, center_shift, n_jobs) # compute new pairwise distances between centers and closest other # center of each center for next iterations @@ -469,7 +530,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, # rerun E-step so that predicted labels match cluster centers elkan_iter(X, sample_weight, centers, centers, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, - lower_bounds, labels, center_shift, n_jobs_, + lower_bounds, labels, center_shift, n_jobs, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -480,24 +541,24 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, n_jobs=None): - """A single run of k-means, assumes preparation completed prior. + """A single run of k-means lloyd, assumes preparation completed prior. Parameters ---------- - X : array-like of floats, shape (n_samples, n_features) + X : array-like or CSR matrix, shape (n_samples, n_features) The observations to cluster. + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + n_clusters : int The number of clusters to form as well as the number of centroids to generate. - sample_weight : array-like, shape (n_samples,) - The weights for each observation in X. - - max_iter : int, optional, default 300 + max_iter : int (default=300) Maximum number of iterations of the k-means algorithm to run. 
- init : {'k-means++', 'random', or ndarray, or a callable}, optional + init : {'k-means++', 'random', ndarray, callable} (default='k-means++') Method for initialization, default to 'k-means++': 'k-means++' : selects initial cluster centers for k-mean @@ -513,29 +574,30 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, If a callable is passed, it should take arguments X, k and and a random state and return an initialization. - tol : float, optional - The relative increment in the results before declaring convergence. - - verbose : boolean, optional + verbose : boolean, optional (default=False) Verbosity mode - x_squared_norms : array + x_squared_norms : array-like or None (default=None) Precomputed x_squared_norms. - random_state : int, RandomState instance or None (default) + random_state : int, RandomState instance or None (default=None) Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. See :term:`Glossary `. - n_jobs : int - The number of threads to be used. If -1, will use as many as possible. + tol : float (default=1e-4) + The relative increment in the results before declaring convergence. + + n_jobs : int or None (default=None) + The number of threads to be used. If -1 or None, will use as many as + possible. Returns ------- - centroid : float ndarray with shape (k, n_features) + centroid : float ndarray, shape (n_clusters, n_features) Centroids found at the last iteration of k-means. - label : integer ndarray with shape (n_samples,) + label : integer ndarray, shape (n_samples,) label[i] is the code or index of the centroid the i'th observation is closest to. @@ -546,7 +608,6 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, n_iter : int Number of iterations run. 
""" - n_jobs_ = effective_n_jobs(n_jobs) random_state = check_random_state(random_state) sample_weight = _check_sample_weight(X, sample_weight) @@ -573,7 +634,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs_) + center_shift, n_jobs) if verbose: inertia = _inertia(X, sample_weight, centers_old, labels) @@ -590,7 +651,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, # rerun E-step so that predicted labels match cluster centers lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs_, update_centers=False) + center_shift, n_jobs, update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -604,7 +665,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): Parameters ---------- - X : float array-like or CSR sparse matrix, shape (n_samples, n_features) + X : array-like or CSR sparse matrix, shape (n_samples, n_features) The input samples to assign to the labels. sample_weight : array-like, shape (n_samples,) @@ -614,7 +675,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): Precomputed squared euclidean norm of each data point, to speed up computations. - centers : float array, shape (n_clusters, n_features) + centers : array, shape (n_clusters, n_features) The cluster centers. Returns @@ -656,23 +717,24 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None, ---------- X : array, shape (n_samples, n_features) + The input samples. k : int - number of centroids + number of centroids. - init : {'k-means++', 'random' or ndarray or callable} optional - Method for initialization + init : {'k-means++', 'random', ndarray, callable} + Method for initialization. 
- random_state : int, RandomState instance or None (default) + random_state : int, RandomState instance or None (default=None) Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. See :term:`Glossary `. - x_squared_norms : array, shape (n_samples,), optional + x_squared_norms : array, shape (n_samples,) (default=None) Squared euclidean norm of each data point. Pass it if you have it at hands already to avoid it being recomputed here. Default: None - init_size : int, optional + init_size : int (default=None) Number of samples to randomly sample for speeding up the initialization (sometimes at the expense of accuracy): the only algorithm is initialized by running a batch KMeans on a @@ -736,12 +798,12 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): Parameters ---------- - n_clusters : int, optional, default: 8 + n_clusters : int (default=8) The number of clusters to form as well as the number of centroids to generate. - init : {'k-means++', 'random' or an ndarray} - Method for initialization, defaults to 'k-means++': + init : {'k-means++', 'random', ndarray, callable} (default='k-means++') + Method for initialization: 'k-means++' : selects initial cluster centers for k-mean clustering in a smart way to speed up convergence. See section @@ -753,19 +815,22 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): If an ndarray is passed, it should be of shape (n_clusters, n_features) and gives the initial centers. - n_init : int, default: 10 + If a callable is passed, it should take arguments X, k and + and a random state and return an initialization. + + n_init : int (default=10) Number of time the k-means algorithm will be run with different centroid seeds. The final results will be the best output of n_init consecutive runs in terms of inertia. - max_iter : int, default: 300 + max_iter : int (default=300) Maximum number of iterations of the k-means algorithm for a single run. 
- tol : float, default: 1e-4 + tol : float (default=1e-4) Relative tolerance with regards to inertia to declare convergence - precompute_distances : {'auto', True, False} + precompute_distances : {'auto', True, False} (default='auto') Precompute distances (faster but takes more memory). 'auto' : do not precompute distances if n_samples * n_clusters > 12 @@ -779,36 +844,35 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): 'precompute_distances' was deprecated in version 0.21 and will be removed in 0.23. - verbose : int, default 0 + verbose : int, optional (default=0) Verbosity mode. - random_state : int, RandomState instance or None (default) + random_state : int, RandomState instance or None (default=None) Determines random number generation for centroid initialization. Use an int to make the randomness deterministic. See :term:`Glossary `. - copy_x : boolean, optional + copy_x : boolean, optional (default=True) When pre-computing distances it is more numerically accurate to center - the data first. If copy_x is True (default), then the original data is + the data first. If copy_x is True (default), then the original data is not modified. If False, the original data is modified, and put back before the function returns, but small numerical differences may be introduced by subtracting and then adding the data mean. Note that if the original data is not C-contiguous, a copy will be made even if - copy_x is False. + copy_x is False. If the original data is sparse, but not in CSR format, + a copy will be made even if copy_x is False. n_jobs : int or None, optional (default=None) The number of jobs to use for the computation. This works by computing each of the n_init runs in parallel. - ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. - ``-1`` means using all processors. See :term:`Glossary ` - for more details. + ``None`` or ``-1`` means using all processors. See + :term:`Glossary ` for more details. 
- algorithm : "auto", "full" or "elkan", default="auto" + algorithm : {"auto", "full", "elkan"} (default="auto") K-means algorithm to use. The classical EM-style algorithm is "full". - The "elkan" variation is more efficient by using the triangle - inequality, but currently doesn't support sparse data. "auto" chooses - "elkan" for dense data and "full" for sparse data. + The "elkan" variation is more efficient, on well structured data, by + using the triangle inequality. "auto" chooses "elkan". Attributes ---------- @@ -908,7 +972,7 @@ def fit(self, X, y=None, sample_weight=None): Parameters ---------- - X : array-like or sparse matrix, shape=(n_samples, n_features) + X : {array-like, sparse matrix}, shape=(n_samples, n_features) Training instances to cluster. It must be noted that the data will be converted to C ordering, which will cause a memory copy if the given data is not C-contiguous. @@ -916,27 +980,20 @@ def fit(self, X, y=None, sample_weight=None): y : Ignored not used, present here for API consistency by convention. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional (default=None) The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight. 
""" - if self.precompute_distances != 'not-used': - warnings.warn("'precompute_distances' was deprecated in version" - "0.21 and will be removed in 0.23.", - DeprecationWarning) - - random_state = check_random_state(self.random_state) - self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ k_means( X, n_clusters=self.n_clusters, sample_weight=sample_weight, init=self.init, n_init=self.n_init, max_iter=self.max_iter, verbose=self.verbose, precompute_distances=self.precompute_distances, - tol=self.tol, random_state=random_state, copy_x=self.copy_x, - n_jobs=self.n_jobs, algorithm=self.algorithm, - return_n_iter=True) + tol=self.tol, random_state=self.random_state, + copy_x=self.copy_x, n_jobs=self.n_jobs, + algorithm=self.algorithm, return_n_iter=True) return self def fit_predict(self, X, y=None, sample_weight=None): @@ -947,19 +1004,19 @@ def fit_predict(self, X, y=None, sample_weight=None): Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape = (n_samples, n_features) New data to transform. y : Ignored not used, present here for API consistency by convention. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional (default=None) The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight. Returns ------- - labels : array, shape [n_samples,] + labels : array, shape (n_samples,) Index of the cluster each sample belongs to. """ return self.fit(X, sample_weight=sample_weight).labels_ @@ -971,19 +1028,19 @@ def fit_transform(self, X, y=None, sample_weight=None): Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape = (n_samples, n_features) New data to transform. y : Ignored not used, present here for API consistency by convention. 
- sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional (default=None) The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight. Returns ------- - X_new : array, shape [n_samples, k] + X_new : array, shape (n_samples, n_clusters) X transformed in the new space. """ # Currently, this just skips a copy of the data if it is not in @@ -1001,12 +1058,12 @@ def transform(self, X): Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape = (n_samples, n_features) New data to transform. Returns ------- - X_new : array, shape [n_samples, k] + X_new : array, shape (n_samples, n_clusters) X transformed in the new space. """ check_is_fitted(self, 'cluster_centers_') @@ -1027,16 +1084,16 @@ def predict(self, X, sample_weight=None): Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape = (n_samples, n_features) New data to predict. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional (default=None) The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight. Returns ------- - labels : array, shape [n_samples,] + labels : array, shape (n_samples,) Index of the cluster each sample belongs to. """ check_is_fitted(self, 'cluster_centers_') @@ -1045,15 +1102,14 @@ def predict(self, X, sample_weight=None): x_squared_norms = row_norms(X, squared=True) return _labels_inertia(X, sample_weight, x_squared_norms, - self.cluster_centers_, - effective_n_jobs(self.n_jobs))[0] + self.cluster_centers_, self.n_jobs)[0] def score(self, X, y=None, sample_weight=None): """Opposite of the value of X on the K-means objective. 
Parameters ---------- - X : {array-like, sparse matrix}, shape = [n_samples, n_features] + X : {array-like, sparse matrix}, shape = (n_samples, n_features) New data. y : Ignored @@ -1061,7 +1117,7 @@ def score(self, X, y=None, sample_weight=None): sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations - are assigned equal weight (default: None) + are assigned equal weight. Returns ------- From 40de5b33106d363a83bc3ee0c5d15620c2f55178 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 14 Mar 2019 14:26:19 +0100 Subject: [PATCH 068/163] fix --- sklearn/cluster/_k_means.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 3459942916abb..b4599413e01d1 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -69,7 +69,7 @@ cdef floating _euclidean_sparse_dense(floating[::1] a_data, if result < 0: result = 0.0 - return result is squared else sqrt(result) + return result if squared else sqrt(result) cpdef floating _inertia_dense(np.ndarray[floating, ndim=2, mode='c'] X, From 0aaee58fcc5480f89828da2782caa26aceff7d4f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 24 Jun 2019 11:51:28 +0200 Subject: [PATCH 069/163] revert last changes: bad scalabilty --- sklearn/cluster/_k_means_lloyd.pyx | 187 ++++++++++++++++------------- sklearn/cluster/k_means_.py | 6 + 2 files changed, 108 insertions(+), 85 deletions(-) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 79278823b6b06..8abcb817cd7a5 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -9,7 +9,7 @@ cimport openmp from cython cimport floating from cython.parallel import prange, parallel from libc.math cimport sqrt -from libc.stdlib cimport malloc, free +from libc.stdlib cimport malloc, calloc, free from libc.string cimport memset, memcpy 
from libc.float cimport DBL_MAX, FLT_MAX @@ -97,43 +97,61 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int j, k - # If n_samples < 256 there's still one chunk of size n_samples_rem - if n_chunks == 0: - n_chunks = 1 - n_samples_chunk = 0 + floating *centers_new_chunk + floating *weight_in_clusters_chunk + floating *pairwise_distances_chunk + + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk # re-initialize all arrays at each iteration centers_squared_norms = row_norms(centers_new, squared=True) - if update_centers: memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_rem - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_dense( - &X[start, 0], - sample_weight[start: end], - x_squared_norms[start: end], - centers_old, - centers_new, - centers_squared_norms, - weight_in_clusters, - labels[start: end], - update_centers) + with nogil, parallel(num_threads=n_jobs): + # thread local buffers + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + pairwise_distances_chunk = malloc(n_samples_chunk * n_clusters * sizeof(floating)) + + for chunk_idx in prange(n_chunks): + start = chunk_idx * n_samples_chunk + if chunk_idx == n_chunks - 1 and n_samples_rem > 0: + end = start + n_samples_rem + else: + end = start + n_samples_chunk + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + 
centers_squared_norms, + labels[start: end], + centers_new_chunk, + weight_in_clusters_chunk, + pairwise_distances_chunk, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. + if update_centers: + with gil: + for j in range(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] + + free(centers_new_chunk) + free(weight_in_clusters_chunk) + free(pairwise_distances_chunk) if update_centers: _relocate_empty_clusters_dense(X, sample_weight, centers_old, - centers_new, weight_in_clusters, labels) + centers_new, weight_in_clusters, labels) _average_centers(centers_new, weight_in_clusters) _center_shift(centers_old, centers_new, center_shift) @@ -143,10 +161,11 @@ cdef void _update_chunk_dense(floating *X, floating[::1] sample_weight, floating[::1] x_squared_norms, floating[:, ::1] centers_old, - floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, int[::1] labels, + floating *centers_new, + floating *weight_in_clusters, + floating *pairwise_distances, bint update_centers) nogil: """K-means combined EM step for one dense data chunk. @@ -161,43 +180,34 @@ cdef void _update_chunk_dense(floating *X, floating sq_dist, min_sq_dist int i, j, k, label - floating *pairwise_distances_ptr = malloc(n_samples * n_clusters * sizeof(floating)) - floating[:, ::1] pairwise_distances - - with gil: - pairwise_distances = pairwise_distances_ptr - # Instead of computing the full pairwise squared distances matrix, # ||X - C||² = ||X||² - 2 X.C^T + ||C||², we only need to store # the - 2 X.C^T + ||C||² term since the argmin for a given sample only # depends on the centers. 
for i in range(n_samples): for j in range(n_clusters): - pairwise_distances[i, j] = centers_squared_norms[j] + pairwise_distances[i * n_clusters + j] = centers_squared_norms[j] _gemm(RowMajor, NoTrans, Trans, n_samples, n_clusters, n_features, -2.0, X, n_features, ¢ers_old[0, 0], n_features, - 1.0, pairwise_distances_ptr, n_clusters) + 1.0, pairwise_distances, n_clusters) for i in range(n_samples): - min_sq_dist = pairwise_distances[i, 0] + min_sq_dist = pairwise_distances[i * n_clusters] label = 0 for j in range(1, n_clusters): - sq_dist = pairwise_distances[i, j] + sq_dist = pairwise_distances[i * n_clusters + j] if sq_dist < min_sq_dist: min_sq_dist = sq_dist label = j labels[i] = label - free(pairwise_distances_ptr) - + # XXX try inside prev loop if update_centers: - # The gil is necessary for that to avoid race conditions. - with gil: - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] - for k in range(n_features): - centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(n_features): + centers_new[labels[i] * n_features + k] += X[i * n_features + k] * sample_weight[i] cpdef void _lloyd_iter_chunked_sparse(X, @@ -271,48 +281,58 @@ cpdef void _lloyd_iter_chunked_sparse(X, int start = 0, end = 0 int j, k - floating alpha floating[::1] X_data = X.data int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - # If n_samples < 256 there's still one chunk of size n_samples_rem - if n_chunks == 0: - n_chunks = 1 - n_samples_chunk = 0 + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk # re-initialize all arrays at each iteration centers_squared_norms = row_norms(centers_new, squared=True) - if update_centers: memcpy(¢ers_old[0, 0], ¢ers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(¢ers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) 
memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_rem - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_sparse( - X_data[X_indptr[start]: X_indptr[end]], - X_indices[X_indptr[start]: X_indptr[end]], - X_indptr[start: end], - sample_weight[start: end], - x_squared_norms[start: end], - centers_old, - centers_new, - centers_squared_norms, - weight_in_clusters, - labels[start: end], - update_centers) - + with nogil, parallel(num_threads=n_jobs): + # thread local buffers + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + + for chunk_idx in prange(n_chunks): + start = chunk_idx * n_samples_chunk + if chunk_idx == n_chunks - 1 and n_samples_rem > 0: + end = start + n_samples_rem + else: + end = start + n_samples_chunk + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + x_squared_norms[start: end], + centers_old, + centers_squared_norms, + labels[start: end], + centers_new_chunk, + weight_in_clusters_chunk, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. 
+ if update_centers: + with gil: + for j in range(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] + + free(centers_new_chunk) + free(weight_in_clusters_chunk) + if update_centers: _relocate_empty_clusters_sparse( X_data, X_indices, X_indptr, sample_weight, @@ -328,10 +348,10 @@ cdef void _update_chunk_sparse(floating[::1] X_data, floating[::1] sample_weight, floating[::1] x_squared_norms, floating[:, ::1] centers_old, - floating[:, ::1] centers_new, floating[::1] centers_squared_norms, - floating[::1] weight_in_clusters, int[::1] labels, + floating *centers_new, + floating *weight_in_clusters, bint update_centers) nogil: """K-means combined EM step for one sparse data chunk. @@ -371,10 +391,7 @@ cdef void _update_chunk_sparse(floating[::1] X_data, labels[i] = label - if update_centers: - # The gil is necessary for that to avoid race conditions. - with gil: - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] - for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] + for i in range(n_samples): + weight_in_clusters[labels[i]] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[labels[i] * n_features + X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 9e9cb40d0b844..33832139b783e 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -368,6 +368,9 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', # limit number of threads in second level of nested parallelism (i.e. 
BLAS) # to avoid oversubsciption + if n_jobs is None: + n_jobs = 1 + with thread_limits_context(limits=1, subset="blas"): for seed in seeds: # run a k-means once @@ -686,6 +689,9 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): inertia : float Sum of squared distances of samples to their closest cluster center. """ + if n_jobs is None: + n_jobs = 1 + n_samples = X.shape[0] sample_weight = _check_sample_weight(X, sample_weight) labels = np.full(n_samples, -1, dtype=np.int32) From 34cd11edcd788718968acd7966b33cdcde47784e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 24 Jun 2019 16:44:46 +0200 Subject: [PATCH 070/163] revert last changes: bad scalability (continued) --- sklearn/cluster/_k_means_elkan.pyx | 167 ++++++++++++++++------------- sklearn/cluster/_k_means_lloyd.pyx | 30 +++--- sklearn/cluster/k_means_.py | 21 ++-- 3 files changed, 116 insertions(+), 102 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 4a6993b260b38..6b5f405d11523 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -260,39 +260,50 @@ shape (n_clusters, n_clusters) int i, j, k - # If n_samples < 256 there's still one chunk of size n_samples_rem - if n_chunks == 0: - n_chunks = 1 - n_samples_chunk = 0 + floating *centers_new_chunk + floating *weight_in_clusters_chunk + + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_rem -
else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_dense( - &X[start, 0], - sample_weight[start: end], - centers_old, - centers_new, - center_half_distances, - distance_next_center, - weight_in_clusters, - labels[start: end], - upper_bounds[start: end], - lower_bounds[start: end], - update_centers) + with nogil, parallel(num_threads=n_jobs): + # thread local buffers + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + + for chunk_idx in prange(n_chunks): + start = chunk_idx * n_samples_chunk + if chunk_idx == n_chunks - 1 and n_samples_rem > 0: + end = start + n_samples_rem + else: + end = start + n_samples_chunk + + _update_chunk_dense( + &X[start, 0], + sample_weight[start: end], + centers_old, + center_half_distances, + distance_next_center, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + centers_new_chunk, + weight_in_clusters_chunk, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. + if update_centers: + with gil: + for j in range(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] if update_centers: _relocate_empty_clusters_dense(X, sample_weight, centers_old, @@ -314,13 +325,13 @@ shape (n_clusters, n_clusters) cdef void _update_chunk_dense(floating *X, floating[::1] sample_weight, floating[:, ::1] centers_old, - floating[:, ::1] centers_new, floating[:, ::1] center_half_distances, floating[::1] distance_next_center, - floating[::1] weight_in_clusters, int[::1] labels, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, + floating *centers_new, + floating *weight_in_clusters, bint update_centers) nogil: """K-means combined EM step for one dense data chunk. 
@@ -377,13 +388,10 @@ cdef void _update_chunk_dense(floating *X, labels[i] = label upper_bounds[i] = upper_bound - if update_centers: - # The gil is necessary for that to avoid race conditions. - with gil: - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] - for k in range(n_features): - centers_new[labels[i], k] += X[i * n_features + k] * sample_weight[i] + if update_centers: + weight_in_clusters[label] += sample_weight[i] + for k in range(n_features): + centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i] cpdef void _elkan_iter_chunked_sparse(X, @@ -476,42 +484,54 @@ shape (n_clusters, n_clusters) floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) - # If n_samples < 256 there's still one chunk of size n_samples_rem - if n_chunks == 0: - n_chunks = 1 - n_samples_chunk = 0 + floating *centers_new_chunk + floating *weight_in_clusters_chunk + + # count remainder chunk in total number of chunks + n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) - for chunk_idx in prange(n_chunks, nogil=True, num_threads=n_jobs): - # remaining samples added to last chunk - if chunk_idx == n_chunks - 1: - n_samples_chunk_eff = n_samples_chunk + n_samples_rem - else: - n_samples_chunk_eff = n_samples_chunk - - start = chunk_idx * n_samples_chunk - end = start + n_samples_chunk_eff - - _update_chunk_sparse( - X_data[X_indptr[start]: X_indptr[end]], - X_indices[X_indptr[start]: X_indptr[end]], - X_indptr[start: end], - sample_weight[start: end], - centers_old, - centers_new, - centers_squared_norms, - center_half_distances, - distance_next_center, - weight_in_clusters, - labels[start: end], - upper_bounds[start: end], -
lower_bounds[start: end], - update_centers) + with nogil, parallel(num_threads=n_jobs): + # thread local buffers + centers_new_chunk = calloc(n_clusters * n_features, sizeof(floating)) + weight_in_clusters_chunk = calloc(n_clusters, sizeof(floating)) + + for chunk_idx in prange(n_chunks): + start = chunk_idx * n_samples_chunk + if chunk_idx == n_chunks - 1 and n_samples_rem > 0: + end = start + n_samples_rem + else: + end = start + n_samples_chunk + + _update_chunk_sparse( + X_data[X_indptr[start]: X_indptr[end]], + X_indices[X_indptr[start]: X_indptr[end]], + X_indptr[start: end], + sample_weight[start: end], + centers_old, + centers_squared_norms, + center_half_distances, + distance_next_center, + labels[start: end], + upper_bounds[start: end], + lower_bounds[start: end], + centers_new_chunk, + weight_in_clusters_chunk, + update_centers) + + # reduction from local buffers. The gil is necessary for that to avoid + # race conditions. + if update_centers: + with gil: + for j in range(n_clusters): + weight_in_clusters[j] += weight_in_clusters_chunk[j] + for k in range(n_features): + centers_new[j, k] += centers_new_chunk[j * n_features + k] + if update_centers: _relocate_empty_clusters_sparse( @@ -536,14 +556,14 @@ cdef void _update_chunk_sparse(floating[::1] X_data, int[::1] X_indptr, floating[::1] sample_weight, floating[:, ::1] centers_old, - floating[:, ::1] centers_new, floating[::1] centers_squared_norms, floating[:, ::1] center_half_distances, floating[::1] distance_next_center, - floating[::1] weight_in_clusters, int[::1] labels, floating[::1] upper_bounds, floating[:, ::1] lower_bounds, + floating *centers_new, + floating *weight_in_clusters, bint update_centers) nogil: """K-means combined EM step for one sparse data chunk. @@ -604,10 +624,7 @@ cdef void _update_chunk_sparse(floating[::1] X_data, labels[i] = label upper_bounds[i] = upper_bound - if update_centers: - # The gil is necessary for that to avoid race conditions. 
- with gil: - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] - for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[labels[i], X_indices[k]] += X_data[k] * sample_weight[i] + if update_centers: + weight_in_clusters[label] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[label * n_features + X_indices[k]] += X_data[k] * sample_weight[i] diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 8abcb817cd7a5..671ed46c34f59 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -8,7 +8,6 @@ cimport numpy as np cimport openmp from cython cimport floating from cython.parallel import prange, parallel -from libc.math cimport sqrt from libc.stdlib cimport malloc, calloc, free from libc.string cimport memset, memcpy from libc.float cimport DBL_MAX, FLT_MAX @@ -29,7 +28,6 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, floating[::1] x_squared_norms, floating[:, ::1] centers_old, floating[:, ::1] centers_new, - floating[::1] centers_squared_norms, floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, @@ -97,6 +95,8 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int j, k + floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + floating *centers_new_chunk floating *weight_in_clusters_chunk floating *pairwise_distances_chunk @@ -104,8 +104,6 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, # count remainder chunk in total number of chunks n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) @@ -202,12
+200,10 @@ cdef void _update_chunk_dense(floating *X, label = j labels[i] = label - # XXX try inside prev loop - if update_centers: - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] + if update_centers: + weight_in_clusters[label] += sample_weight[i] for k in range(n_features): - centers_new[labels[i] * n_features + k] += X[i * n_features + k] * sample_weight[i] + centers_new[label * n_features + k] += X[i * n_features + k] * sample_weight[i] cpdef void _lloyd_iter_chunked_sparse(X, @@ -215,7 +211,6 @@ cpdef void _lloyd_iter_chunked_sparse(X, floating[::1] x_squared_norms, floating[:, ::1] centers_old, floating[:, ::1] centers_new, - floating[::1] centers_squared_norms, floating[::1] weight_in_clusters, int[::1] labels, floating[::1] center_shift, @@ -286,11 +281,14 @@ cpdef void _lloyd_iter_chunked_sparse(X, int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr + floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + + floating *centers_new_chunk + floating *weight_in_clusters_chunk + # count remainder chunk in total number of chunks n_chunks += n_samples != n_chunks * n_samples_chunk - # re-initialize all arrays at each iteration - centers_squared_norms = row_norms(centers_new, squared=True) if update_centers: memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) @@ -391,7 +389,7 @@ cdef void _update_chunk_sparse(floating[::1] X_data, labels[i] = label - for i in range(n_samples): - weight_in_clusters[labels[i]] += sample_weight[i] - for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): - centers_new[labels[i] * n_features + X_indices[k]] += X_data[k] * sample_weight[i] + if update_centers: + weight_in_clusters[label] += sample_weight[i] + for k in range(X_indptr[i] - s, X_indptr[i + 1] - s): + centers_new[label * n_features + X_indices[k]] += X_data[k] * sample_weight[i] diff --git
a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 33832139b783e..c0db15c918803 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -622,7 +622,6 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, print("Initialization complete") centers_old = np.zeros_like(centers) - centers_squared_norms = np.zeros(n_clusters, dtype=X.dtype) labels = np.full(X.shape[0], -1, dtype=np.int32) weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) center_shift = np.zeros(n_clusters, dtype=X.dtype) @@ -636,8 +635,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, for i in range(max_iter): lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, - centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs) + weight_in_clusters, labels, center_shift, n_jobs) if verbose: inertia = _inertia(X, sample_weight, centers_old, labels) @@ -653,8 +651,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, # rerun E-step so that predicted labels match cluster centers lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, - centers_squared_norms, weight_in_clusters, labels, - center_shift, n_jobs, update_centers=False) + weight_in_clusters, labels, center_shift, n_jobs, + update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) @@ -693,11 +691,12 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): n_jobs = 1 n_samples = X.shape[0] + n_clusters = centers.shape[0] + sample_weight = _check_sample_weight(X, sample_weight) labels = np.full(n_samples, -1, dtype=np.int32) - centers_squared_norms = np.zeros(centers.shape[0], dtype=centers.dtype) - weight_in_clusters = np.zeros_like(centers_squared_norms) - center_shift = np.zeros_like(centers_squared_norms) + weight_in_clusters = np.zeros(n_clusters, dtype=centers.dtype) + center_shift = np.zeros_like(weight_in_clusters) if sp.issparse(X): _labels = 
_lloyd_iter_chunked_sparse @@ -706,9 +705,9 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): _labels = _lloyd_iter_chunked_dense _inertia = _inertia_dense - _labels(X, sample_weight, x_squared_norms, centers, - centers, centers_squared_norms, weight_in_clusters, - labels, center_shift, n_jobs, update_centers=False) + _labels(X, sample_weight, x_squared_norms, centers, centers, + weight_in_clusters, labels, center_shift, n_jobs, + update_centers=False) inertia = _inertia(X, sample_weight, centers, labels) From 6c13a7d5326539ee163c5b6ccb108d71befd1679 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 24 Jun 2019 16:59:58 +0200 Subject: [PATCH 071/163] merge master --- sklearn/cluster/_k_means_lloyd.pyx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index 671ed46c34f59..e7c1ff8839d66 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -1,5 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False -# cython: language_level=3 +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True # # Licence: BSD 3 clause From d8439fd7cf8b41bac60970952dbe9778c91ef461 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 26 Jun 2019 14:08:16 +0200 Subject: [PATCH 072/163] openmp helper equivalent of effective_n_jobs --- sklearn/utils/openmp_helpers.pyx | 26 ++++++++++++++++++++++++++ sklearn/utils/setup.py | 4 ++++ 2 files changed, 30 insertions(+) create mode 100644 sklearn/utils/openmp_helpers.pyx diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx new file mode 100644 index 0000000000000..5c8c06cd160cf --- /dev/null +++ b/sklearn/utils/openmp_helpers.pyx @@ -0,0 +1,26 @@ +cimport openmp +from joblib import effective_n_jobs + + +cpdef _openmp_effective_n_threads(n_threads=None): + """Determine the 
effective number of threads used for parallel OpenMP calls + + - For ``n_threads = None``, returns the minimum between + openmp.omp_get_max_threads() and joblib.effective_n_jobs(-1). + - For ``n_threads > 0``, use this as the maximal number of threads for + parallel OpenMP calls. + - For ``n_threads < 0``, use the maximal number of threads minus + ``|n_threads + 1|``. + - Raise a ValueError for ``n_threads = 0``. + """ + if n_threads == 0: + raise ValueError("n_threads = 0 is invalid") + + max_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) + + if n_threads is None: + return max_threads + elif n_threads < 0: + return max(1, max_threads + n_threads + 1) + + return n_threads diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index f3002ed3ffed9..593739915f3f8 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -46,6 +46,10 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_extension('openmp_helpers', + sources=['openmp_helpers.pyx'], + libraries=libraries) + # generate files from a template pyx_templates = ['sklearn/utils/seq_dataset.pyx.tp', 'sklearn/utils/seq_dataset.pxd.tp'] From b8900ab23c7311f3508ca9c684b584c4876a0af6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 26 Jun 2019 14:58:03 +0200 Subject: [PATCH 073/163] protect openmp calls --- sklearn/utils/openmp_helpers.pyx | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index 5c8c06cd160cf..ff972de2d10e3 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -1,4 +1,6 @@ -cimport openmp +IF SKLEARN_OPENMP_SUPPORTED: + cimport openmp + from joblib import effective_n_jobs @@ -12,15 +14,23 @@ cpdef _openmp_effective_n_threads(n_threads=None): - For ``n_threads < 0``, use the maximal number of threads minus ``|n_threads + 1|``. 
- Raise a ValueError for ``n_threads = 0``. + + If scikit-learn is built without OpenMP support, always return 1. """ if n_threads == 0: raise ValueError("n_threads = 0 is invalid") - max_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) + IF SKLEARN_OPENMP_SUPPORTED: + max_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) + + if n_threads is None: + return max_threads + elif n_threads < 0: + return max(1, max_threads + n_threads + 1) - if n_threads is None: - return max_threads - elif n_threads < 0: - return max(1, max_threads + n_threads + 1) + return n_threads + ELSE: + # OpenMP not supported => sequential mode + return 1 - return n_threads + From e47bdb842765a6fdf1178f52887a0d87864e8274 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 26 Jun 2019 17:30:27 +0200 Subject: [PATCH 074/163] comment openmp max threads --- sklearn/utils/openmp_helpers.pyx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index ff972de2d10e3..d45ba73595028 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -21,12 +21,14 @@ cpdef _openmp_effective_n_threads(n_threads=None): raise ValueError("n_threads = 0 is invalid") IF SKLEARN_OPENMP_SUPPORTED: - max_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) + # omp_get_max_threads can be influenced by environement variable + # OMP_NUM_THREADS or at runtime by omp_set_num_threads + max_n_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) if n_threads is None: - return max_threads + return max_n_threads elif n_threads < 0: - return max(1, max_threads + n_threads + 1) + return max(1, max_n_threads + n_threads + 1) return n_threads ELSE: From 8050149a2a3a59b8a0117d51234841a64cffc4cd Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 26 Jun 2019 17:32:14 +0200 Subject: [PATCH 075/163] right place comment --- 
sklearn/utils/openmp_helpers.pyx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index d45ba73595028..296473aa13499 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -9,6 +9,8 @@ cpdef _openmp_effective_n_threads(n_threads=None): - For ``n_threads = None``, returns the minimum between openmp.omp_get_max_threads() and joblib.effective_n_jobs(-1). + The result of ``omp_get_max_threads`` can be influenced by environement + variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. - For ``n_threads > 0``, use this as the maximal number of threads for parallel OpenMP calls. - For ``n_threads < 0``, use the maximal number of threads minus @@ -21,8 +23,6 @@ cpdef _openmp_effective_n_threads(n_threads=None): raise ValueError("n_threads = 0 is invalid") IF SKLEARN_OPENMP_SUPPORTED: - # omp_get_max_threads can be influenced by environement variable - # OMP_NUM_THREADS or at runtime by omp_set_num_threads max_n_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) if n_threads is None: From 753272203477a647f0cde2ad6f33e1a01347bcb8 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 27 Jun 2019 10:56:18 +0200 Subject: [PATCH 076/163] avoid copy centers_old <-> centers_new --- sklearn/cluster/_k_means_lloyd.pyx | 6 ++---- sklearn/cluster/k_means_.py | 8 +++++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/cluster/_k_means_lloyd.pyx b/sklearn/cluster/_k_means_lloyd.pyx index e7c1ff8839d66..edbc882439588 100644 --- a/sklearn/cluster/_k_means_lloyd.pyx +++ b/sklearn/cluster/_k_means_lloyd.pyx @@ -94,7 +94,7 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, int j, k - floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + floating[::1] centers_squared_norms = row_norms(centers_old, squared=True) floating *centers_new_chunk floating 
*weight_in_clusters_chunk @@ -104,7 +104,6 @@ cpdef void _lloyd_iter_chunked_dense(np.ndarray[floating, ndim=2, mode='c'] X, n_chunks += n_samples != n_chunks * n_samples_chunk if update_centers: - memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) @@ -280,7 +279,7 @@ cpdef void _lloyd_iter_chunked_sparse(X, int[::1] X_indices = X.indices int[::1] X_indptr = X.indptr - floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + floating[::1] centers_squared_norms = row_norms(centers_old, squared=True) floating *centers_new_chunk floating *weight_in_clusters_chunk @@ -289,7 +288,6 @@ cpdef void _lloyd_iter_chunked_sparse(X, n_chunks += n_samples != n_chunks * n_samples_chunk if update_centers: - memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index d2077b9ac3134..83c3d1ce360b9 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -621,7 +621,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, if verbose: print("Initialization complete") - centers_old = np.zeros_like(centers) + centers_new = np.zeros_like(centers) labels = np.full(X.shape[0], -1, dtype=np.int32) weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) center_shift = np.zeros(n_clusters, dtype=X.dtype) @@ -634,11 +634,11 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, _inertia = _inertia_dense for i in range(max_iter): - lloyd_iter(X, sample_weight, x_squared_norms, centers_old, centers, + lloyd_iter(X, sample_weight, x_squared_norms, centers, centers_new, weight_in_clusters, labels, center_shift, n_jobs) if
verbose: - inertia = _inertia(X, sample_weight, centers_old, labels) + inertia = _inertia(X, sample_weight, centers, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) center_shift_tot = (center_shift**2).sum() @@ -649,6 +649,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, .format(i, center_shift_tot, tol)) break + centers, centers_new = centers_new, centers + # rerun E-step so that predicted labels match cluster centers lloyd_iter(X, sample_weight, x_squared_norms, centers, centers, weight_in_clusters, labels, center_shift, n_jobs, From 54f814688ad432e5c09b5b1e2441281bfdb53beb Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 27 Jun 2019 11:08:14 +0200 Subject: [PATCH 077/163] avoid copy centers_old <-> centers_new --- sklearn/cluster/_k_means_elkan.pyx | 4 +--- sklearn/cluster/k_means_.py | 10 ++++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 734d4d359c053..d7bbdb6fc647e 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -266,7 +266,6 @@ shape (n_clusters, n_clusters) n_chunks += n_samples != n_chunks * n_samples_chunk if update_centers: - memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters * n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) @@ -481,7 +480,7 @@ shape (n_clusters, n_clusters) int i, j, k - floating[::1] centers_squared_norms = row_norms(centers_new, squared=True) + floating[::1] centers_squared_norms = row_norms(centers_old, squared=True) floating *centers_new_chunk floating *weight_in_clusters_chunk @@ -490,7 +489,6 @@ shape (n_clusters, n_clusters) n_chunks += n_samples != n_chunks * n_samples_chunk if update_centers: - memcpy(&centers_old[0, 0], &centers_new[0, 0], n_clusters * n_features * sizeof(floating)) memset(&centers_new[0, 0], 0, n_clusters *
n_features * sizeof(floating)) memset(&weight_in_clusters[0], 0, n_clusters * sizeof(floating)) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 83c3d1ce360b9..67d677f654b7d 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -485,7 +485,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, n_samples = X.shape[0] - centers_old = np.zeros_like(centers) + centers_new = np.zeros_like(centers) weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype) labels = np.full(n_samples, -1, dtype=np.int32) center_half_distances = euclidean_distances(centers) / 2 @@ -508,18 +508,18 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, labels, upper_bounds, lower_bounds) for i in range(max_iter): - elkan_iter(X, sample_weight, centers_old, centers, weight_in_clusters, + elkan_iter(X, sample_weight, centers, centers_new, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, lower_bounds, labels, center_shift, n_jobs) # compute new pairwise distances between centers and closest other # center of each center for next iterations - center_half_distances = euclidean_distances(centers) / 2 + center_half_distances = euclidean_distances(centers_new) / 2 distance_next_center = np.partition(np.asarray(center_half_distances), kth=1, axis=0)[1] if verbose: - inertia = _inertia(X, sample_weight, centers_old, labels) + inertia = _inertia(X, sample_weight, centers, labels) print("Iteration {0}, inertia {1}" .format(i, inertia)) center_shift_tot = (center_shift**2).sum() @@ -530,6 +530,8 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, .format(i, center_shift_tot, tol)) break + centers, centers_new = centers_new, centers + # rerun E-step so that predicted labels match cluster centers elkan_iter(X, sample_weight, centers, centers, weight_in_clusters, center_half_distances, distance_next_center, upper_bounds, From 280f5516f895007bf26a5a6940b580709f1747e5 Mon Sep 
17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 5 Jul 2019 15:18:58 +0200 Subject: [PATCH 078/163] don't import joblib if unecessary --- sklearn/utils/openmp_helpers.pyx | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index 296473aa13499..9fb01b3d29e2f 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -1,7 +1,6 @@ IF SKLEARN_OPENMP_SUPPORTED: cimport openmp - -from joblib import effective_n_jobs + from joblib import effective_n_jobs cpdef _openmp_effective_n_threads(n_threads=None): From 8f5ebfde1b09e587750bedee18e2332b87342d32 Mon Sep 17 00:00:00 2001 From: jeremiedbb <34657725+jeremiedbb@users.noreply.github.com> Date: Mon, 5 Aug 2019 01:04:55 +0200 Subject: [PATCH 079/163] Update sklearn/utils/openmp_helpers.pyx Co-Authored-By: Joel Nothman --- sklearn/utils/openmp_helpers.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index 9fb01b3d29e2f..d798024bc269d 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -8,7 +8,7 @@ cpdef _openmp_effective_n_threads(n_threads=None): - For ``n_threads = None``, returns the minimum between openmp.omp_get_max_threads() and joblib.effective_n_jobs(-1). - The result of ``omp_get_max_threads`` can be influenced by environement + The result of ``omp_get_max_threads`` can be influenced by environment variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. - For ``n_threads > 0``, use this as the maximal number of threads for parallel OpenMP calls. 
From edebabf6960ef68c828892a45dd4aba3e3e15d2a Mon Sep 17 00:00:00 2001 From: jeremiedbb <34657725+jeremiedbb@users.noreply.github.com> Date: Fri, 9 Aug 2019 15:48:21 +0200 Subject: [PATCH 080/163] Update sklearn/utils/openmp_helpers.pyx Co-Authored-By: Thomas Moreau --- sklearn/utils/openmp_helpers.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index d798024bc269d..bf80c77c453bf 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -1,6 +1,6 @@ IF SKLEARN_OPENMP_SUPPORTED: cimport openmp - from joblib import effective_n_jobs + from joblib import cpu_count cpdef _openmp_effective_n_threads(n_threads=None): From 4e994521529d60e317d1df9dd6fdc9e82b9d078c Mon Sep 17 00:00:00 2001 From: jeremiedbb <34657725+jeremiedbb@users.noreply.github.com> Date: Fri, 9 Aug 2019 15:48:35 +0200 Subject: [PATCH 081/163] Update sklearn/utils/openmp_helpers.pyx Co-Authored-By: Thomas Moreau --- sklearn/utils/openmp_helpers.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index bf80c77c453bf..dd035fc1ef481 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -22,7 +22,7 @@ cpdef _openmp_effective_n_threads(n_threads=None): raise ValueError("n_threads = 0 is invalid") IF SKLEARN_OPENMP_SUPPORTED: - max_n_threads = min(openmp.omp_get_max_threads(), effective_n_jobs(-1)) + max_n_threads = min(openmp.omp_get_max_threads(), cpu_count()) if n_threads is None: return max_n_threads From 0a9545025d0ab6638c052223a45a5c3c91ea3fa6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 13 Sep 2019 18:46:51 +0200 Subject: [PATCH 082/163] vendor threadpoolctl --- sklearn/externals/_threadpoolctl.py | 582 ++++++++++++++++++++++ sklearn/externals/vendor_threadpoolctl.sh | 30 ++ 2 files changed, 612 insertions(+) create mode 100644 
sklearn/externals/_threadpoolctl.py create mode 100755 sklearn/externals/vendor_threadpoolctl.sh diff --git a/sklearn/externals/_threadpoolctl.py b/sklearn/externals/_threadpoolctl.py new file mode 100644 index 0000000000000..524ae02dc09d6 --- /dev/null +++ b/sklearn/externals/_threadpoolctl.py @@ -0,0 +1,582 @@ +"""threadpoolctl + +This module provides utilities to introspect native libraries that relies on +thread pools (notably BLAS and OpenMP implementations) and dynamically set the +maximal number of threads they can use. +""" +# License: BSD 3-Clause + +# The code to introspect dynamically loaded libraries on POSIX systems is +# adapted from code by Intel developper @anton-malakhov available at +# https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation) +# and also published under the BSD 3-Clause license +import os +import re +import sys +import ctypes +import warnings +from ctypes.util import find_library + +__version__ = '1.1.0' +__all__ = ["threadpool_limits", "threadpool_info"] + +# Cache for libc under POSIX and a few system libraries under Windows +_system_libraries = {} + +# Cache for calls to os.path.realpath on system libraries to reduce the +# impact of slow system calls (e.g. stat) on slow filesystem +_realpaths = dict() + +# One can get runtime errors or even segfaults due to multiple OpenMP libraries +# loaded simultaneously which can happen easily in Python when importing and +# using compiled extensions built with different compilers and therefore +# different OpenMP runtimes in the same program. In particular libiomp (used by +# Intel ICC) and libomp used by clang/llvm tend to crash. This can happen for +# instance when calling BLAS inside a prange. Setting the following environment +# variable allows multiple OpenMP libraries to be loaded. 
It should not degrade +# performances since we manually take care of potential over-subscription +# performance issues, in sections of the code where nested OpenMP loops can +# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily +# disable it while under the scope of the outer OpenMP parallel section. +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True") + + +# Structure to cast the info on dynamically loaded library. See +# https://linux.die.net/man/3/dl_iterate_phdr for more details. + +_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32 +_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16 + + +class _dl_phdr_info(ctypes.Structure): + _fields_ = [ + ("dlpi_addr", _SYSTEM_UINT), # Base address of object + ("dlpi_name", ctypes.c_char_p), # path to the library + ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers + ("dlpi_phnum", _SYSTEM_UINT_HALF) # number of element in dlpi_phdr + ] + + +# List of the supported implementations. The items hold the prefix of loaded +# shared objects, the name of the internal_api to call, matching the +# MAP_API_TO_FUNC keys and the name of the user_api, in {"blas", "openmp"}. 
+ +_SUPPORTED_IMPLEMENTATIONS = [ + { + "user_api": "openmp", + "internal_api": "openmp", + "filename_prefixes": ("libiomp", "libgomp", "libomp", "vcomp",), + }, + { + "user_api": "blas", + "internal_api": "openblas", + "filename_prefixes": ("libopenblas",), + }, + { + "user_api": "blas", + "internal_api": "mkl", + "filename_prefixes": ("libmkl_rt", "mkl_rt",), + }, + { + "user_api": "blas", + "internal_api": "blis", + "filename_prefixes": ("libblis",), + }, +] + +# map a internal_api (openmp, openblas, mkl) to set and get functions + +_MAP_API_TO_FUNC = { + "openmp": { + "set_num_threads": "omp_set_num_threads", + "get_num_threads": "omp_get_max_threads"}, + "openblas": { + "set_num_threads": "openblas_set_num_threads", + "get_num_threads": "openblas_get_num_threads"}, + "mkl": { + "set_num_threads": "MKL_Set_Num_Threads", + "get_num_threads": "MKL_Get_Max_Threads"}, + "blis": { + "set_num_threads": "bli_thread_set_num_threads", + "get_num_threads": "bli_thread_get_num_threads"} +} + +# Helpers for the doc and test names + +_ALL_USER_APIS = set(impl['user_api'] for impl in _SUPPORTED_IMPLEMENTATIONS) +_ALL_PREFIXES = [prefix + for impl in _SUPPORTED_IMPLEMENTATIONS + for prefix in impl['filename_prefixes']] +_ALL_INTERNAL_APIS = list(_MAP_API_TO_FUNC.keys()) + + +def _realpath(filepath, cache_limit=10000): + """Small caching wrapper around os.path.realpath to limit system calls""" + rpath = _realpaths.get(filepath) + if rpath is None: + rpath = os.path.realpath(filepath) + if len(_realpaths) < cache_limit: + # If we drop support for Python 2.7, we could use functools.lru_cache + # with maxsize=10000 instead. 
+ _realpaths[filepath] = rpath + return rpath + + +def _format_docstring(*args, **kwargs): + def decorator(o): + o.__doc__ = o.__doc__.format(*args, **kwargs) + return o + + return decorator + + +def _get_limit(prefix, user_api, limits): + if prefix in limits: + return limits[prefix] + else: + return limits[user_api] + + +@_format_docstring(ALL_PREFIXES=_ALL_PREFIXES, + INTERNAL_APIS=_ALL_INTERNAL_APIS) +def _set_threadpool_limits(limits, user_api=None): + """Limit the maximal number of threads for threadpools in supported libs + + Set the maximal number of threads that can be used in thread pools used in + the supported native libraries to `limit`. This function works for + libraries that are already loaded in the interpreter and can be changed + dynamically. + + The `limits` parameter can be either an integer or a dict to specify the + maximal number of thread that can be used in thread pools. If it is an + integer, sets the maximum number of thread to `limits` for each library + selected by `user_api`. If it is a dictionary `{{key: max_threads}}`, this + function sets a custom maximum number of thread for each `key` which can be + either a `user_api` or a `prefix` for a specific library. + + The `user_api` parameter selects particular APIs of libraries to limit. + Used only if `limits` is an int. If it is None, this function will apply to + all supported libraries. If it is "blas", it will limit only BLAS supported + libraries and if it is "openmp", only OpenMP supported libraries will be + limited. Note that the latter can affect the number of threads used by the + BLAS libraries if they rely on OpenMP. + + Return a list with all the supported modules that have been found. Each + module is represented by a dict with the following information: + - 'filename_prefixes' : possible prefixes for the given internal_api. + Possible values are {ALL_PREFIXES}. + - 'prefix' : prefix of the specific implementation of this module. 
+ - 'internal_api': internal API.s Possible values are {INTERNAL_APIS}. + - 'filepath': path to the loaded module. + - 'version': version of the library implemented (if available). + - 'num_threads': the theadpool size limit before changing it. + - 'set_num_threads': callable to set the maximum number of threads + - 'get_num_threads': callable to get the current number of threads + - 'dynlib': the instance of ctypes.CDLL use to access the dynamic + library. + """ + if isinstance(limits, int): + if user_api is None: + user_api = _ALL_USER_APIS + elif user_api in _ALL_USER_APIS: + user_api = (user_api,) + else: + raise ValueError("user_api must be either in {} or None. Got {} " + "instead.".format(_ALL_USER_APIS, user_api)) + limits = {api: limits for api in user_api} + prefixes = [] + else: + if isinstance(limits, list): + # This should be a list of module, for compatibility with + # the result from threadpool_info. + limits = {module['prefix']: module['num_threads'] + for module in limits} + + if not isinstance(limits, dict): + raise TypeError("limits must either be an int, a list or a dict." + " Got {} instead".format(type(limits))) + + # With a dictionary, can set both specific limit for given modules + # and global limit for user_api. Fetch each separately. 
+ prefixes = [module for module in limits if module in _ALL_PREFIXES] + user_api = [module for module in limits if module in _ALL_USER_APIS] + + modules = _load_modules(prefixes=prefixes, user_api=user_api) + for module in modules: + # Workaround clang bug (TODO: report it) + module['get_num_threads']() + + for module in modules: + module['num_threads'] = module['get_num_threads']() + num_threads = _get_limit(module['prefix'], module['user_api'], limits) + if num_threads is not None: + set_func = module['set_num_threads'] + set_func(num_threads) + + return modules + + +@_format_docstring(INTERNAL_APIS=_ALL_INTERNAL_APIS) +def threadpool_info(): + """Return the maximal number of threads for each detected library. + + Return a list with all the supported modules that have been found. Each + module is represented by a dict with the following information: + - 'prefix' : filename prefix of the specific implementation. + - 'filepath': path to the loaded module. + - 'internal_api': internal API. Possible values are {INTERNAL_APIS}. + - 'version': version of the library implemented (if available). + - 'num_threads': the current thread limit. + """ + infos = [] + modules = _load_modules(user_api=_ALL_USER_APIS) + for module in modules: + module['num_threads'] = module['get_num_threads']() + # by default BLIS is single-threaded and get_num_threads returns -1. + # we map it to 1 for consistency with other libraries. + if module['num_threads'] == -1 and module['internal_api'] == 'blis': + module['num_threads'] = 1 + # Remove the wrapper for the module and its function + del module['set_num_threads'], module['get_num_threads'] + del module['dynlib'] + del module['filename_prefixes'] + infos.append(module) + return infos + + +def _get_version(dynlib, internal_api): + if internal_api == "mkl": + return _get_mkl_version(dynlib) + elif internal_api == "openmp": + # There is no way to get the version number programmatically in + # OpenMP. 
+ return None + elif internal_api == "openblas": + return _get_openblas_version(dynlib) + elif internal_api == "blis": + return _get_blis_version(dynlib) + else: + raise NotImplementedError("Unsupported API {}".format(internal_api)) + + +def _get_mkl_version(mkl_dynlib): + """Return the MKL version""" + res = ctypes.create_string_buffer(200) + mkl_dynlib.mkl_get_version_string(res, 200) + + version = res.value.decode('utf-8') + group = re.search(r"Version ([^ ]+) ", version) + if group is not None: + version = group.groups()[0] + return version.strip() + + +def _get_openblas_version(openblas_dynlib): + """Return the OpenBLAS version + + None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS + did not expose its version before that. + """ + get_config = getattr(openblas_dynlib, "openblas_get_config") + get_config.restype = ctypes.c_char_p + config = get_config().split() + if config[0] == b"OpenBLAS": + return config[1].decode('utf-8') + return None + + +def _get_blis_version(blis_dynlib): + """Return the BLIS version""" + get_version = getattr(blis_dynlib, "bli_info_get_version_str") + get_version.restype = ctypes.c_char_p + return get_version().decode('utf-8') + + +# Loading utilities for dynamically linked shared objects + +def _load_modules(prefixes=None, user_api=None): + """Loop through loaded libraries and return supported ones.""" + if prefixes is None: + prefixes = [] + if user_api is None: + user_api = [] + if sys.platform == "darwin": + return _find_modules_with_dyld(prefixes=prefixes, user_api=user_api) + elif sys.platform == "win32": + return _find_modules_with_enum_process_module_ex( + prefixes=prefixes, user_api=user_api) + else: + return _find_modules_with_dl_iterate_phdr( + prefixes=prefixes, user_api=user_api) + + +def _check_prefix(library_basename, filename_prefixes): + """Return the prefix library_basename starts with or None if none matches + """ + for prefix in filename_prefixes: + if library_basename.startswith(prefix): + return 
prefix + return None + + +def _match_module(module_info, prefix, prefixes, user_api): + """Return True if this module should be selected.""" + return prefix is not None and (prefix in prefixes or + module_info['user_api'] in user_api) + + +def _make_module_info(filepath, module_info, prefix): + """Make a dict with the information from the module.""" + filepath = os.path.normpath(filepath) + dynlib = ctypes.CDLL(filepath) + internal_api = module_info['internal_api'] + set_func = getattr(dynlib, + _MAP_API_TO_FUNC[internal_api]['set_num_threads'], + lambda num_threads: None) + get_func = getattr(dynlib, + _MAP_API_TO_FUNC[internal_api]['get_num_threads'], + lambda: None) + module_info = module_info.copy() + module_info.update(dynlib=dynlib, filepath=filepath, prefix=prefix, + set_num_threads=set_func, get_num_threads=get_func, + version=_get_version(dynlib, internal_api)) + return module_info + + +def _get_module_info_from_path(filepath, prefixes, user_api, modules): + # Required to resolve symlinks + filepath =_realpath(filepath) + # `lower` required to take account of OpenMP dll case on Windows + # (vcomp, VCOMP, Vcomp, ...) + filename = os.path.basename(filepath).lower() + for info in _SUPPORTED_IMPLEMENTATIONS: + prefix = _check_prefix(filename, info['filename_prefixes']) + if _match_module(info, prefix, prefixes, user_api): + modules.append(_make_module_info(filepath, info, prefix)) + + +def _find_modules_with_dl_iterate_phdr(prefixes, user_api): + """Loop through loaded libraries and return binders on supported ones + + This function is expected to work on POSIX system only. 
+ This code is adapted from code by Intel developper @anton-malakhov + available at https://github.com/IntelPython/smp + + Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause + license + """ + libc = _get_libc() + if not hasattr(libc, "dl_iterate_phdr"): # pragma: no cover + return [] + + _modules = [] + + # Callback function for `dl_iterate_phdr` which is called for every + # module loaded in the current process until it returns 1. + def match_module_callback(info, size, data): + # Get the path of the current module + filepath = info.contents.dlpi_name + if filepath: + filepath = filepath.decode("utf-8") + + # Store the module in cls_thread_locals._module if it is + # supported and selected + _get_module_info_from_path(filepath, prefixes, user_api, + _modules) + return 0 + + c_func_signature = ctypes.CFUNCTYPE( + ctypes.c_int, # Return type + ctypes.POINTER(_dl_phdr_info), ctypes.c_size_t, ctypes.c_char_p) + c_match_module_callback = c_func_signature(match_module_callback) + + data = ctypes.c_char_p(b'') + libc.dl_iterate_phdr(c_match_module_callback, data) + + return _modules + + +def _find_modules_with_dyld(prefixes, user_api): + """Loop through loaded libraries and return binders on supported ones + + This function is expected to work on OSX system only + """ + libc = _get_libc() + if not hasattr(libc, "_dyld_image_count"): # pragma: no cover + return [] + + _modules = [] + + n_dyld = libc._dyld_image_count() + libc._dyld_get_image_name.restype = ctypes.c_char_p + + for i in range(n_dyld): + filepath = ctypes.string_at(libc._dyld_get_image_name(i)) + filepath = filepath.decode("utf-8") + + # Store the module in cls_thread_locals._module if it is supported and + # selected + _get_module_info_from_path(filepath, prefixes, user_api, _modules) + + return _modules + + +def _find_modules_with_enum_process_module_ex(prefixes, user_api): + """Loop through loaded libraries and return binders on supported ones + + This function is expected to work on 
windows system only. + This code is adapted from code by Philipp Hagemeister @phihag available + at https://stackoverflow.com/questions/17474574 + """ + from ctypes.wintypes import DWORD, HMODULE, MAX_PATH + + PROCESS_QUERY_INFORMATION = 0x0400 + PROCESS_VM_READ = 0x0010 + + LIST_MODULES_ALL = 0x03 + + ps_api = _get_windll('Psapi') + kernel_32 = _get_windll('kernel32') + + h_process = kernel_32.OpenProcess( + PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, + False, os.getpid()) + if not h_process: # pragma: no cover + raise OSError('Could not open PID %s' % os.getpid()) + + _modules = [] + try: + buf_count = 256 + needed = DWORD() + # Grow the buffer until it becomes large enough to hold all the + # module headers + while True: + buf = (HMODULE * buf_count)() + buf_size = ctypes.sizeof(buf) + if not ps_api.EnumProcessModulesEx( + h_process, ctypes.byref(buf), buf_size, + ctypes.byref(needed), LIST_MODULES_ALL): + raise OSError('EnumProcessModulesEx failed') + if buf_size >= needed.value: + break + buf_count = needed.value // (buf_size // buf_count) + + count = needed.value // (buf_size // buf_count) + h_modules = map(HMODULE, buf[:count]) + + # Loop through all the module headers and get the module path + buf = ctypes.create_unicode_buffer(MAX_PATH) + n_size = DWORD() + for h_module in h_modules: + + # Get the path of the current module + if not ps_api.GetModuleFileNameExW( + h_process, h_module, ctypes.byref(buf), + ctypes.byref(n_size)): + raise OSError('GetModuleFileNameEx failed') + filepath = buf.value + + # Store the module in cls_thread_locals._module if it is + # supported and selected + _get_module_info_from_path(filepath, prefixes, user_api, + _modules) + finally: + kernel_32.CloseHandle(h_process) + + return _modules + + +def _get_libc(): + """Load the lib-C for unix systems.""" + libc = _system_libraries.get("libc") + if libc is None: + libc_name = find_library("c") + if libc_name is None: # pragma: no cover + return None + libc = ctypes.CDLL(libc_name) + 
_system_libraries["libc"] = libc + return libc + + +def _get_windll(dll_name): + """Load a windows DLL""" + dll = _system_libraries.get(dll_name) + if dll is None: + dll = ctypes.WinDLL("{}.dll".format(dll_name)) + _system_libraries[dll_name] = dll + return dll + + +class threadpool_limits: + """Change the maximal number of threads that can be used in thread pools. + + This class can be used either as a function (the construction of this + object limits the number of threads) or as a context manager, in a `with` + block. + + Set the maximal number of threads that can be used in thread pools used in + the supported libraries to `limit`. This function works for libraries that + are already loaded in the interpreter and can be changed dynamically. + + The `limits` parameter can be either an integer or a dict to specify the + maximal number of thread that can be used in thread pools. If it is an + integer, sets the maximum number of thread to `limits` for each library + selected by `user_api`. If it is a dictionary `{{key: max_threads}}`, this + function sets a custom maximum number of thread for each `key` which can be + either a `user_api` or a `prefix` for a specific library. If None, this + function does not do anything. + + The `user_api` parameter selects particular APIs of libraries to limit. + Used only if `limits` is an int. If it is None, this function will apply to + all supported libraries. If it is "blas", it will limit only BLAS supported + libraries and if it is "openmp", only OpenMP supported libraries will be + limited. Note that the latter can affect the number of threads used by the + BLAS libraries if they rely on OpenMP. 
+ """ + def __init__(self, limits=None, user_api=None): + self._user_api = _ALL_USER_APIS if user_api is None else [user_api] + + if limits is not None: + self._original_limits = _set_threadpool_limits( + limits=limits, user_api=user_api) + else: + self._original_limits = None + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.unregister() + + def unregister(self): + if self._original_limits is not None: + for module in self._original_limits: + module['set_num_threads'](module['num_threads']) + + def get_original_num_threads(self): + original_limits = self._original_limits or threadpool_info() + + num_threads = {} + warning_apis = [] + + for user_api in self._user_api: + limits = [module['num_threads'] for module in original_limits + if module['user_api'] == user_api] + limits = set(limits) + n_limits = len(limits) + + if n_limits == 1: + limit = limits.pop() + elif n_limits == 0: + limit = None + else: + limit = min(limits) + warning_apis.append(user_api) + + num_threads[user_api] = limit + + if warning_apis: + warnings.warn("Multiple value possible for following user apis: " + + ', '.join(warning_apis) + ". 
Returning the minimum.") + + return num_threads diff --git a/sklearn/externals/vendor_threadpoolctl.sh b/sklearn/externals/vendor_threadpoolctl.sh new file mode 100755 index 0000000000000..5a4eed62e368b --- /dev/null +++ b/sklearn/externals/vendor_threadpoolctl.sh @@ -0,0 +1,30 @@ +#!/bin/sh +# Script to do a local install of threadpoolctl +set +x +export LC_ALL=C +INSTALL_FOLDER=threadpoolctl_install +rm -rf _threadpoolctl.py $INSTALL_FOLDER 2> /dev/null +if [ -z "$1" ] +then + # Grab the latest stable release from PyPI + THREADPOOLCTL=threadpoolctl +else + THREADPOOLCTL=$1 +fi +pip install --no-cache $THREADPOOLCTL --target $INSTALL_FOLDER +cp $INSTALL_FOLDER/threadpoolctl.py _threadpoolctl.py +rm -rf $INSTALL_FOLDER + +# Needed to rewrite the doctests +# Note: BSD sed -i needs an argument unders OSX +# so first renaming to .bak and then deleting backup files +#find loky -name "*.py" | xargs sed -i.bak "s/from loky/from joblib.externals.loky/" +#find loky -name "*.bak" | xargs rm + +#for f in $(git grep -l "cloudpickle" loky); do +# echo $f; +# sed -i 's/import cloudpickle/from joblib.externals import cloudpickle/' $f +# sed -i 's/from cloudpickle import/from joblib.externals.cloudpickle import/' $f +# done + +# sed -i "s/loky.backend.popen_loky/joblib.externals.loky.backend.popen_loky/" loky/backend/popen_loky_posix.py From 6cb945b204c1041613b393403ff41fd05e6c0d00 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 16 Sep 2019 11:50:31 +0200 Subject: [PATCH 083/163] remove _clibs --- sklearn/cluster/k_means_.py | 6 +- sklearn/utils/_clibs.py | 411 ------------------------------ sklearn/utils/tests/test_clibs.py | 114 --------- 3 files changed, 3 insertions(+), 528 deletions(-) delete mode 100644 sklearn/utils/_clibs.py delete mode 100644 sklearn/utils/tests/test_clibs.py diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 270c2003a2531..ef6c1fbc05114 100644 --- a/sklearn/cluster/k_means_.py +++ 
b/sklearn/cluster/k_means_.py @@ -27,7 +27,7 @@ from ..utils import check_random_state from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.validation import FLOAT_DTYPES -from ..utils._clibs import thread_limits_context +from ..externals._threadpoolctl import threadpool_limits from ..exceptions import ConvergenceWarning from ._k_means import _inertia_dense from ._k_means import _inertia_sparse @@ -519,7 +519,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, Number of iterations run. """ random_state = check_random_state(random_state) - sample_weight = _check_sample_weight(X, sample_weight) + sample_weight = _check_normalize_sample_weight(sample_weight, X) # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, @@ -977,7 +977,7 @@ def fit(self, X, y=None, sample_weight=None): # limit number of threads in second level of nested parallelism # (i.e. BLAS) to avoid oversubsciption. - with thread_limits_context(limits=1, subset="blas"): + with threadpool_limits(limits=1, user_api="blas"): for seed in seeds: # run a k-means once labels, inertia, centers, n_iter_ = kmeans_single( diff --git a/sklearn/utils/_clibs.py b/sklearn/utils/_clibs.py deleted file mode 100644 index 0fab3924cf737..0000000000000 --- a/sklearn/utils/_clibs.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -This module provides utilities to load C-libraries that relies on thread -pools and limit the maximal number of thread that can be used. -""" - -# This code is adapted from code by Thomas Moreau available at -# https://github.com/tomMoral/loky - - -import sys -import os -import threading -import ctypes -from ctypes.util import find_library -from contextlib import contextmanager as contextmanager - - -# Structure to cast the info on dynamically loaded library. See -# https://linux.die.net/man/3/dl_iterate_phdr for more details. 
-UINT_SYSTEM = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32 -UINT_HALF_SYSTEM = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16 - - -class dl_phdr_info(ctypes.Structure): - _fields_ = [ - ("dlpi_addr", UINT_SYSTEM), # Base address of object - ("dlpi_name", ctypes.c_char_p), # path to the library - ("dlpi_phdr", ctypes.c_void_p), # pointer on dlpi_headers - ("dlpi_phnum", UINT_HALF_SYSTEM) # number of element in dlpi_phdr - ] - - -class _CLibsWrapper: - # Wrapper around classic C-libraries for scientific computations to set and - # get the maximum number of threads they are allowed to used for inner - # parallelism. - - # Supported C-libraries for this wrapper, index with their name. The items - # hold the name of the library file and the functions to call. - SUPPORTED_CLIBS = { - "openmp_intel": ( - "libiomp", "omp_set_num_threads", "omp_get_max_threads"), - "openmp_gnu": ( - "libgomp", "omp_set_num_threads", "omp_get_max_threads"), - "openmp_llvm": ( - "libomp", "omp_set_num_threads", "omp_get_max_threads"), - "openmp_win32": ( - "vcomp", "omp_set_num_threads", "omp_get_max_threads"), - "openblas": ( - "libopenblas", "openblas_set_num_threads", - "openblas_get_num_threads"), - "mkl": ( - "libmkl_rt", "MKL_Set_Num_Threads", "MKL_Get_Max_Threads"), - "mkl_win32": ( - "mkl_rt", "MKL_Set_Num_Threads", "MKL_Get_Max_Threads")} - - cls_thread_locals = threading.local() - - def __init__(self): - self._load() - - def _load(self): - for clib, (module_name, _, _) in self.SUPPORTED_CLIBS.items(): - setattr(self, clib, self._load_lib(module_name)) - - def _unload(self): - for clib, (module_name, _, _) in self.SUPPORTED_CLIBS.items(): - delattr(self, clib) - - def set_thread_limits(self, limits=1, subset=None): - """Limit maximal number of threads used by supported C-libraries""" - if isinstance(limits, int): - if subset in ("all", None): - clibs = self.SUPPORTED_CLIBS.keys() - elif subset == "blas": - clibs = ("openblas", "mkl", "mkl_win32") - elif 
subset == "openmp": - clibs = (c for c in self.SUPPORTED_CLIBS if "openmp" in c) - else: - raise ValueError("subset must be either 'all', 'blas' or " - "'openmp'. Got {} instead.".format(subset)) - limits = {clib: limits for clib in clibs} - - if not isinstance(limits, dict): - raise TypeError("limits must either be an int or a dict. Got {} " - "instead".format(type(limits))) - - dynamic_threadpool_size = {} - self._load() - for clib, (_, _set, _) in self.SUPPORTED_CLIBS.items(): - if clib in limits: - module = getattr(self, clib, None) - if module is not None: - _set = getattr(module, _set) - _set(limits[clib]) - dynamic_threadpool_size[clib] = True - else: - dynamic_threadpool_size[clib] = False - else: - dynamic_threadpool_size[clib] = False - self._unload() - return dynamic_threadpool_size - - def get_thread_limits(self): - """Return maximal number of threads available for supported C-libraries - """ - limits = {} - self._load() - for clib, (_, _, _get) in self.SUPPORTED_CLIBS.items(): - module = getattr(self, clib, None) - if module is not None: - _get = getattr(module, _get) - limits[clib] = _get() - else: - limits[clib] = None - self._unload() - return limits - - def get_openblas_version(self): - module = getattr(self, "openblas", None) - if module is not None: - get_config = getattr(module, "openblas_get_config") - get_config.restype = ctypes.c_char_p - config = get_config().split() - if config[0] == b"OpenBLAS": - return config[1].decode('utf-8') - return - return - - def _load_lib(self, module_name): - """Return a binder on module_name by looping through loaded libraries - """ - if sys.platform == "darwin": - return self._find_with_clibs_dyld(module_name) - elif sys.platform == "win32": - return self._find_with_clibs_enum_process_module_ex(module_name) - return self._find_with_clibs_dl_iterate_phdr(module_name) - - def _find_with_clibs_dl_iterate_phdr(self, module_name): - """Return a binder on module_name by looping through loaded libraries - - This 
function is expected to work on POSIX system only. - This code is adapted from code by Intel developper @anton-malakhov - available at https://github.com/IntelPython/smp - - Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause - license - """ - self.cls_thread_locals._module_path = None - - libc = self._get_libc() - if not hasattr(libc, "dl_iterate_phdr"): - return - - # Callback function for `dl_iterate_phdr` which is called for every - # module loaded in the current process until it returns 1. - def match_module_callback(info, size, module_name): - - # recast the name of the module as a string - module_name = ctypes.string_at(module_name).decode('utf-8') - - # Get the name of the current library - module_path = info.contents.dlpi_name - - # If the current library is the one we are looking for, store the - # path and return 1 to stop the loop in `dl_iterate_phdr`. - if module_path: - module_path = module_path.decode("utf-8") - if os.path.basename(module_path).startswith(module_name): - self.cls_thread_locals._module_path = module_path - return 1 - return 0 - - c_func_signature = ctypes.CFUNCTYPE( - ctypes.c_int, # Return type - ctypes.POINTER(dl_phdr_info), ctypes.c_size_t, ctypes.c_char_p) - c_match_module_callback = c_func_signature(match_module_callback) - - data = ctypes.c_char_p(module_name.encode('utf-8')) - res = libc.dl_iterate_phdr(c_match_module_callback, data) - if res == 1: - return ctypes.CDLL(self.cls_thread_locals._module_path) - - def _find_with_clibs_dyld(self, module_name): - """Return a binder on module_name by looping through loaded libraries - - This function is expected to work on OSX system only - """ - libc = self._get_libc() - if not hasattr(libc, "_dyld_image_count"): - return - - found_module_path = None - - n_dyld = libc._dyld_image_count() - libc._dyld_get_image_name.restype = ctypes.c_char_p - - for i in range(n_dyld): - module_path = ctypes.string_at(libc._dyld_get_image_name(i)) - module_path = 
module_path.decode("utf-8") - if os.path.basename(module_path).startswith(module_name): - found_module_path = module_path - - if found_module_path: - return ctypes.CDLL(found_module_path) - - def _find_with_clibs_enum_process_module_ex(self, module_name): - """Return a binder on module_name by looping through loaded libraries - - This function is expected to work on windows system only. - This code is adapted from code by Philipp Hagemeister @phihag available - at https://stackoverflow.com/questions/17474574 - """ - from ctypes.wintypes import DWORD, HMODULE, MAX_PATH - - PROCESS_QUERY_INFORMATION = 0x0400 - PROCESS_VM_READ = 0x0010 - - LIST_MODULES_ALL = 0x03 - - Psapi = self._get_windll('Psapi') - Kernel32 = self._get_windll('kernel32') - - hProcess = Kernel32.OpenProcess( - PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, - False, os.getpid()) - if not hProcess: - raise OSError('Could not open PID %s' % os.getpid()) - - found_module_path = None - try: - buf_count = 256 - needed = DWORD() - # Grow the buffer until it becomes large enough to hold all the - # module headers - while True: - buf = (HMODULE * buf_count)() - buf_size = ctypes.sizeof(buf) - if not Psapi.EnumProcessModulesEx( - hProcess, ctypes.byref(buf), buf_size, - ctypes.byref(needed), LIST_MODULES_ALL): - raise OSError('EnumProcessModulesEx failed') - if buf_size >= needed.value: - break - buf_count = needed.value // (buf_size // buf_count) - - count = needed.value // (buf_size // buf_count) - hModules = map(HMODULE, buf[:count]) - - # Loop through all the module headers and get the module file name - buf = ctypes.create_unicode_buffer(MAX_PATH) - nSize = DWORD() - for hModule in hModules: - if not Psapi.GetModuleFileNameExW( - hProcess, hModule, ctypes.byref(buf), - ctypes.byref(nSize)): - raise OSError('GetModuleFileNameEx failed') - module_path = buf.value - module_basename = os.path.basename(module_path).lower() - if module_basename.startswith(module_name): - found_module_path = module_path - 
finally: - Kernel32.CloseHandle(hProcess) - - if found_module_path: - return ctypes.CDLL(found_module_path) - - def _get_libc(self): - if not hasattr(self, "libc"): - libc_name = find_library("c") - if libc_name is None: - self.libc = None - self.libc = ctypes.CDLL(libc_name) - - return self.libc - - def _get_windll(self, dll_name): - if not hasattr(self, dll_name): - setattr(self, dll_name, ctypes.WinDLL("{}.dll".format(dll_name))) - - return getattr(self, dll_name) - - -_clibs_wrapper = None - - -def _get_wrapper(reload_clib=False): - """Helper function to only create one wrapper per thread.""" - global _clibs_wrapper - if _clibs_wrapper is None: - _clibs_wrapper = _CLibsWrapper() - if reload_clib: - _clibs_wrapper._load() - return _clibs_wrapper - - -def set_thread_limits(limits=1, subset=None, reload_clib=False): - """Limit the number of threads available for threadpools in supported C-lib - - Set the maximal number of thread that can be used in thread pools used in - the supported C-libraries. This function works for libraries that are - already loaded in the interpreter and can be changed dynamically. - - Parameters - ---------- - limits : int or dict, (default=1) - Maximum number of thread that can be used in thread pools - - If int, sets the maximum number of thread to `limits` for each C-lib - selected by `subset`. - - If dict(supported_libraries: max_threads), sets a custom maximum number - of thread for each C-lib. - - subset : string or None, optional (default="all") - Subset of C-libs to limit. Used only if `limits` is an int - - "all" : limit all supported C-libs. - - "blas" : limit only BLAS supported C-libs. - - "openmp" : limit only OpenMP supported C-libs. It can affect the number - of threads used by the BLAS C-libs if they rely on OpenMP. - - reload_clib : bool, (default=False) - If `reload_clib` is `True`, first loop through the loaded libraries to - ensure that this function is called on all available libraries. 
- - Returns - ------- - dynamic_threadpool_size : dict - contains pairs `('clib': boolean)` which are True if `clib` have been - found and can be used to scale the maximal number of threads - dynamically. - """ - wrapper = _get_wrapper(reload_clib) - return wrapper.set_thread_limits(limits, subset) - - -def get_thread_limits(reload_clib=True): - """Return maximal thread number for threadpools in supported C-lib - - Parameters - ---------- - reload_clib : bool, (default=True) - If `reload_clib` is `True`, first loop through the loaded libraries to - ensure that this function is called on all available libraries. - - Returns - ------- - thread_limits : dict - Contains the maximal number of threads that can be used in supported - libraries or None when the library is not available. The key of the - dictionary are "openmp_gnu", "openmp_intel", "openmp_win32", - "openmp_llvm", "openblas", "mkl" and "mkl_win32". - """ - wrapper = _get_wrapper(reload_clib) - return wrapper.get_thread_limits() - - -@contextmanager -def thread_limits_context(limits=1, subset=None): - """Context manager for C-libs thread limits - - Parameters - ---------- - limits : int or dict, (default=1) - Maximum number of thread that can be used in thread pools - - If int, sets the maximum number of thread to `limits` for each C-lib - selected by `subset`. - - If dict(supported_libraries: max_threads), sets a custom maximum number - of thread for each C-lib. - - subset : string or None, optional (default="all") - Subset of C-libs to limit. Used only if `limits` is an int - - "all" : limit all supported C-libs. - - "blas" : limit only BLAS supported C-libs. - - "openmp" : limit only OpenMP supported C-libs. It can affect the number - of threads used by the BLAS C-libs if they rely on OpenMP. 
- """ - old_limits = get_thread_limits() - set_thread_limits(limits=limits, subset=subset) - - try: - yield - finally: - set_thread_limits(limits=old_limits) - - -def get_openblas_version(reload_clib=True): - """Return the OpenBLAS version - - Parameters - ---------- - reload_clib : bool, (default=True) - If `reload_clib` is `True`, first loop through the loaded libraries to - ensure that this function is called on all available libraries. - - Returns - ------- - version : string or None - None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS - did not expose it's verion before that. - """ - wrapper = _get_wrapper(reload_clib) - return wrapper.get_openblas_version() diff --git a/sklearn/utils/tests/test_clibs.py b/sklearn/utils/tests/test_clibs.py deleted file mode 100644 index 43aad0d8666a8..0000000000000 --- a/sklearn/utils/tests/test_clibs.py +++ /dev/null @@ -1,114 +0,0 @@ -import os - -import pytest - -from sklearn.utils.testing import SkipTest -from sklearn.utils._clibs import (get_thread_limits, set_thread_limits, - get_openblas_version, thread_limits_context, - _CLibsWrapper) - - -SKIP_OPENBLAS = get_openblas_version() is None - - -def test_openmp_enabled(): - # Check that an OpenMP library is loaded - limits = get_thread_limits() - - assert not all([lib is None for lib in [limits['openmp_llvm'], - limits['openmp_gnu'], - limits['openmp_win32'], - limits['openmp_intel']]]) - - -@pytest.mark.parametrize("clib", _CLibsWrapper.SUPPORTED_CLIBS) -def test_set_thread_limits_dict(clib): - # Check that the number of threads used by the multithreaded C-libs can be - # modified dynamically. 
- - if clib == "openblas" and SKIP_OPENBLAS: - raise SkipTest("Possible bug in getting maximum number of threads with" - " OpenBLAS < 0.2.16 and OpenBLAS does not expose it's " - "version before 0.3.4.") - - old_limits = get_thread_limits() - - if old_limits[clib] is not None: - dynamic_scaling = set_thread_limits(limits={clib: 1}) - assert get_thread_limits()[clib] == 1 - assert dynamic_scaling[clib] - - set_thread_limits(limits={clib: 3}) - new_limits = get_thread_limits() - assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) - - set_thread_limits(limits=old_limits) - new_limits = get_thread_limits() - assert new_limits[clib] == old_limits[clib] - - -@pytest.mark.parametrize("subset", ("all", "blas", "openmp")) -def test_set_thread_limits_subset(subset): - # Check that the number of threads used by the multithreaded C-libs can be - # modified dynamically. - - if subset == "all": - clibs = list(_CLibsWrapper.SUPPORTED_CLIBS.keys()) - elif subset == "blas": - clibs = ["openblas", "mkl", "mkl_win32"] - elif subset == "openmp": - clibs = list(c for c in _CLibsWrapper.SUPPORTED_CLIBS if "openmp" in c) - - if SKIP_OPENBLAS and "openblas" in clibs: - clibs.remove("openblas") - - old_limits = get_thread_limits() - - dynamic_scaling = set_thread_limits(limits=1, subset=subset) - new_limits = get_thread_limits() - for clib in clibs: - if old_limits[clib] is not None: - assert new_limits[clib] == 1 - assert dynamic_scaling[clib] - - set_thread_limits(limits=3, subset=subset) - new_limits = get_thread_limits() - for clib in clibs: - if old_limits[clib] is not None: - assert new_limits[clib] in (3, os.cpu_count(), os.cpu_count() / 2) - - set_thread_limits(limits=old_limits) - new_limits = get_thread_limits() - for clib in clibs: - if old_limits[clib] is not None: - assert new_limits[clib] == old_limits[clib] - - -def test_set_thread_limits_bad_input(): - # Check that appropriate errors are raised for invalid arguments - - with pytest.raises(ValueError, - 
match="subset must be either 'all', 'blas' " - "or 'openmp'"): - set_thread_limits(limits=1, subset="wrong") - - with pytest.raises(TypeError, - match="limits must either be an int or a dict"): - set_thread_limits(limits=(1, 2, 3)) - - -def test_thread_limit_context(): - old_limits = get_thread_limits() - - with thread_limits_context(limits=1): - limits = get_thread_limits() - if SKIP_OPENBLAS: - del limits["openblas"] - - for clib in limits: - if old_limits[clib] is None: - assert limits[clib] is None - else: - assert limits[clib] == 1 - - assert get_thread_limits() == old_limits From 6dd45252d56bf256c9c737e4db1fc5ce8c279a49 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 16 Sep 2019 11:57:18 +0200 Subject: [PATCH 084/163] fix merge mistakes --- sklearn/cluster/k_means_.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index ef6c1fbc05114..e79bdd7325d2f 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -378,7 +378,6 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, Number of iterations run. 
""" random_state = check_random_state(random_state) - sample_weight = _check_normalize_sample_weight(sample_weight, X) # init @@ -602,7 +601,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers, n_jobs=1): n_samples = X.shape[0] n_clusters = centers.shape[0] - sample_weight = _check_sample_weight(X, sample_weight) + sample_weight = _check_normalize_sample_weight(sample_weight, X) labels = np.full(n_samples, -1, dtype=np.int32) weight_in_clusters = np.zeros(n_clusters, dtype=centers.dtype) center_shift = np.zeros_like(weight_in_clusters) From 0278f677af7d58319e3690698242c2df3eb7cd79 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 16 Sep 2019 13:47:26 +0200 Subject: [PATCH 085/163] cln --- sklearn/cluster/_k_means.pyx | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 1304af86e83f4..87ab781988ee2 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -171,6 +171,7 @@ cdef void _relocate_empty_clusters_dense(np.ndarray[floating, ndim=2, mode='c'] weight_in_clusters[new_cluster_id] = weight weight_in_clusters[old_cluster_id] -= weight + cdef void _relocate_empty_clusters_sparse(floating[::1] X_data, int[::1] X_indices, int[::1] X_indptr, From aa8eebaff6f31693c335433cd9cca2cb9d5ff513 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 17 Sep 2019 15:43:09 +0200 Subject: [PATCH 086/163] revert appveyor modifs --- appveyor.yml | 4 +--- build_tools/appveyor/requirements.txt | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index f766b4efb30b0..a75281522f7ba 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -85,7 +85,7 @@ test_script: } else { $env:PYTEST_ARGS = "" } - - "pytest --showlocals --durations=20 %PYTEST_ARGS% --pyargs --cov=sklearn sklearn" + - "pytest --showlocals --durations=20 %PYTEST_ARGS% --pyargs sklearn" # Move back to the project folder - cd "../scikit-learn" @@ -94,8 +94,6 
@@ artifacts: - path: dist\* on_success: - - "cp ../empty_folder/.coverage ." - - codecov # Upload the generated wheel package to Rackspace - "python -m wheelhouse_uploader upload --local-folder=dist sklearn-windows-wheels" diff --git a/build_tools/appveyor/requirements.txt b/build_tools/appveyor/requirements.txt index 40ddc39003e27..1a2feca5c6b6b 100644 --- a/build_tools/appveyor/requirements.txt +++ b/build_tools/appveyor/requirements.txt @@ -2,9 +2,6 @@ numpy scipy cython pytest -pytest-cov -coverage -codecov wheel wheelhouse_uploader pillow From f1231f55f29e9171c8421d3fc8c02dcc8ec068f2 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Tue, 17 Sep 2019 17:00:20 +0200 Subject: [PATCH 087/163] improve docstring --- sklearn/utils/openmp_helpers.pyx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index dd035fc1ef481..d193ab821828d 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -7,13 +7,14 @@ cpdef _openmp_effective_n_threads(n_threads=None): """Determine the effective number of threads used for parallel OpenMP calls - For ``n_threads = None``, returns the minimum between - openmp.omp_get_max_threads() and joblib.effective_n_jobs(-1). + ``openmp.omp_get_max_threads()`` and ``joblib.cpu_count()``. The result of ``omp_get_max_threads`` can be influenced by environment variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. - For ``n_threads > 0``, use this as the maximal number of threads for parallel OpenMP calls. - For ``n_threads < 0``, use the maximal number of threads minus - ``|n_threads + 1|``. + ``|n_threads + 1|``. In particular ``n_threads=-1`` will use as many + threads as there are available cores on the machine. - Raise a ValueError for ``n_threads = 0``. If scikit-learn is built without OpenMP support, always return 1. 
From f23ccbb5e11f411bdee7035284ff767a26beb845 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 18 Sep 2019 15:27:29 +0200 Subject: [PATCH 088/163] test deprecated precompute distance --- sklearn/cluster/k_means_.py | 24 +++++++++++++----------- sklearn/cluster/tests/test_k_means.py | 12 ++++++++++++ 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index e79bdd7325d2f..c0ad8dc70bd3b 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -182,7 +182,7 @@ def _check_normalize_sample_weight(sample_weight, X): def k_means(X, n_clusters, sample_weight=None, init='k-means++', - precompute_distances='not-used', n_init=10, max_iter=300, + precompute_distances='deprecated', n_init=10, max_iter=300, verbose=False, tol=1e-4, random_state=None, copy_x=True, n_jobs=None, algorithm="auto", return_n_iter=False): """K-means clustering algorithm. @@ -230,9 +230,10 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', True : always precompute distances False : never precompute distances - .. deprecated:: 0.21 - 'precompute_distances' was deprecated in version 0.21 and will be - removed in 0.23. + + .. deprecated:: 0.22 + 'precompute_distances' was deprecated in version 0.22 and will be + removed in 0.24. n_init : int, (default=10) Number of time the k-means algorithm will be run with different @@ -753,9 +754,10 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator): True : always precompute distances False : never precompute distances - .. deprecated:: 0.21 - 'precompute_distances' was deprecated in version 0.21 and will be - removed in 0.23. + + .. deprecated:: 0.22 + 'precompute_distances' was deprecated in version 0.22 and will be + removed in 0.24. verbose : int, optional (default=0) Verbosity mode. 
@@ -853,7 +855,7 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator): """ def __init__(self, n_clusters=8, init='k-means++', n_init=10, - max_iter=300, tol=1e-4, precompute_distances='not-used', + max_iter=300, tol=1e-4, precompute_distances='deprecated', verbose=0, random_state=None, copy_x=True, n_jobs=None, algorithm='auto'): @@ -900,9 +902,9 @@ def fit(self, X, y=None, sample_weight=None): """ random_state = check_random_state(self.random_state) - if self.precompute_distances != 'not-used': - warnings.warn("'precompute_distances' was deprecated in version" - "0.21 and will be removed in 0.23.", + if self.precompute_distances != 'deprecated': + warnings.warn("'precompute_distances' was deprecated in version " + "0.22 and will be removed in 0.24.", DeprecationWarning) n_init = self.n_init diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 4e4793583fbae..25eba8b544616 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -987,3 +987,15 @@ def test_result_of_kmeans_equal_in_diff_n_jobs(): result_1 = KMeans(n_clusters=3, random_state=0, n_jobs=1).fit(X).labels_ result_2 = KMeans(n_clusters=3, random_state=0, n_jobs=2).fit(X).labels_ assert_array_equal(result_1, result_2) + + +@pytest.mark.parametrize("precompute_distances", ["auto", False, True]) +def test_precompute_distance_deprecated(precompute_distances): + # FIXME: remove in 0.24 + depr_msg = "'precompute_distances' was deprecated in version 0.22" + X, _ = make_blobs(n_samples=100, n_features=2, centers=2, random_state=0) + kmeans = KMeans(n_clusters=2, n_init=1, init='random', random_state=0, + precompute_distances=precompute_distances) + + with pytest.warns(DeprecationWarning, match=depr_msg): + kmeans.fit(X) From e2dd616c53428800c18d7eac1a7ae828c306b185 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 18 Sep 2019 15:37:02 +0200 Subject: [PATCH 089/163] test elkan + 1 cluster warning --- 
sklearn/cluster/tests/test_k_means.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 25eba8b544616..26b7a35b896d2 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -993,9 +993,20 @@ def test_result_of_kmeans_equal_in_diff_n_jobs(): def test_precompute_distance_deprecated(precompute_distances): # FIXME: remove in 0.24 depr_msg = "'precompute_distances' was deprecated in version 0.22" - X, _ = make_blobs(n_samples=100, n_features=2, centers=2, random_state=0) + X, _ = make_blobs(n_samples=10, n_features=2, centers=2, random_state=0) kmeans = KMeans(n_clusters=2, n_init=1, init='random', random_state=0, precompute_distances=precompute_distances) with pytest.warns(DeprecationWarning, match=depr_msg): kmeans.fit(X) + + +def test_warning_elkan_1_cluster(): + X, _ = make_blobs(n_samples=10, n_features=2, centers=1, random_state=0) + kmeans = KMeans(n_clusters=1, n_init=1, init='random', random_state=0, + algorithm='elkan') + + with pytest.warns(RuntimeWarning, + match="algorithm='elkan' doesn't make sense for a single" + " cluster"): + kmeans.fit(X) From 4e2ff78c5518a48ba60c833e8f773d80a3603fb6 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 18 Sep 2019 16:16:11 +0200 Subject: [PATCH 090/163] test error wrong algo --- sklearn/cluster/tests/test_k_means.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 26b7a35b896d2..2eab6d6fece24 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -1010,3 +1010,13 @@ def test_warning_elkan_1_cluster(): match="algorithm='elkan' doesn't make sense for a single" " cluster"): kmeans.fit(X) + + +def test_error_wrong_algorithm(): + X, _ = make_blobs(n_samples=10, n_features=2, centers=2, random_state=0) + kmeans = 
KMeans(n_clusters=2, n_init=1, init='random', random_state=0, + algorithm='wrong') + + with pytest.raises(ValueError, + match="Algorithm must be 'auto', 'full' or 'elkan'"): + kmeans.fit(X) From b9af0a6addef80308e408545be3cc56dba035d0b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 18 Sep 2019 22:23:28 +0200 Subject: [PATCH 091/163] Make it explicit that LOKY_MAX_CPU_COUNT can impact _openmp_effective_n_threads --- sklearn/utils/openmp_helpers.pyx | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/openmp_helpers.pyx b/sklearn/utils/openmp_helpers.pyx index d193ab821828d..0481f5104a2e6 100644 --- a/sklearn/utils/openmp_helpers.pyx +++ b/sklearn/utils/openmp_helpers.pyx @@ -10,6 +10,9 @@ cpdef _openmp_effective_n_threads(n_threads=None): ``openmp.omp_get_max_threads()`` and ``joblib.cpu_count()``. The result of ``omp_get_max_threads`` can be influenced by environment variable ``OMP_NUM_THREADS`` or at runtime by ``omp_set_num_threads``. + The value returned by ``joblib.cpu_count()`` can be controlled by + setting the ``LOKY_MAX_CPU_COUNT`` environment variable (instead of + returning the number of available CPU cores). - For ``n_threads > 0``, use this as the maximal number of threads for parallel OpenMP calls. 
- For ``n_threads < 0``, use the maximal number of threads minus From d9ea936b6641de2649c8bc29d3859521e2e81216 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 18 Sep 2019 22:38:44 +0200 Subject: [PATCH 092/163] Use _openmp_effective_n_threads in KMeans.fit --- sklearn/cluster/k_means_.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index c0ad8dc70bd3b..3d33f4bb8c39f 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -27,6 +27,7 @@ from ..utils import check_random_state from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.validation import FLOAT_DTYPES +from ..utils.openmp_helpers import _openmp_effective_n_threads from ..externals._threadpoolctl import threadpool_limits from ..exceptions import ConvergenceWarning from ._k_means import _inertia_dense @@ -974,7 +975,7 @@ def fit(self, X, y=None, sample_weight=None): # seeds for the initializations of the kmeans runs. seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) - n_jobs = 1 if self.n_jobs is None else self.n_jobs + n_jobs = _openmp_effective_n_threads(self.n_jobs) # limit number of threads in second level of nested parallelism # (i.e. BLAS) to avoid oversubsciption. 
From 851b05fbbdb3a7f2e5346818c5761acfa91ef134 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 19 Sep 2019 16:54:11 +0200 Subject: [PATCH 093/163] cln --- sklearn/cluster/_k_means.pyx | 4 +--- sklearn/utils/sparsefuncs_fast.pyx | 12 ++++++------ sklearn/utils/tests/test_sparsefuncs.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 87ab781988ee2..e781c0c1facc1 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -1,4 +1,4 @@ -# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True, initializedcheck=False +# cython: profile=True, boundscheck=False, wraparound=False, cdivision=True # Profiling is enabled by default as the overhead does not seem to be # measurable on this specific use case. @@ -7,8 +7,6 @@ # Lars Buitinck # # License: BSD 3 clause -# -# cython: boundscheck=False, wraparound=False, cdivision=True import numpy as np cimport numpy as np diff --git a/sklearn/utils/sparsefuncs_fast.pyx b/sklearn/utils/sparsefuncs_fast.pyx index 211d5dcf074c9..f4da67f1e63d0 100644 --- a/sklearn/utils/sparsefuncs_fast.pyx +++ b/sklearn/utils/sparsefuncs_fast.pyx @@ -31,17 +31,17 @@ def csr_row_norms(X): if X.dtype not in [np.float32, np.float64]: X = X.astype(np.float64) - norms = np.zeros(X.shape[0], dtype=X.data.dtype) - _csr_row_norms(X.data, X.shape, X.indices, X.indptr, norms) - - return norms + norms = np.empty(X.shape[0], dtype=X.data.dtype) + _csr_row_norms(X.data, X.shape, X.indices, X.indptr, out=norms) + return norms + def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data, shape, np.ndarray[integral, ndim=1, mode="c"] X_indices, np.ndarray[integral, ndim=1, mode="c"] X_indptr, - floating[::1] norms): + floating[::1] out): cdef: unsigned long long n_samples = shape[0] @@ -53,7 +53,7 @@ def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data, sum_ = 0.0 for j in range(X_indptr[i], X_indptr[i + 
1]): sum_ += X_data[j] * X_data[j] - norms[i] = sum_ + out[i] = sum_ def csr_mean_variance_axis0(X): diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py index 4fed6bdd395a3..cc23a47ede4a4 100644 --- a/sklearn/utils/tests/test_sparsefuncs.py +++ b/sklearn/utils/tests/test_sparsefuncs.py @@ -531,7 +531,7 @@ def test_inplace_normalize(): @pytest.mark.parametrize("dtype", [np.float32, np.float64]) def test_csr_row_norms(dtype): # checks that csr_row_norms returns the same output as - # scipy.sparse.linalg.norm, and that the dype is the same X's. + # scipy.sparse.linalg.norm, and that the dype is the same as X.dtype. X = sp.random(100, 10, format='csr', dtype=dtype) scipy_norms = sp.linalg.norm(X, axis=1)**2 From de02372b5839525b4e5f74985c6356ee13769516 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 19 Sep 2019 17:02:35 +0200 Subject: [PATCH 094/163] cln --- sklearn/externals/vendor_threadpoolctl.sh | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sklearn/externals/vendor_threadpoolctl.sh b/sklearn/externals/vendor_threadpoolctl.sh index 5a4eed62e368b..321a006648d1d 100755 --- a/sklearn/externals/vendor_threadpoolctl.sh +++ b/sklearn/externals/vendor_threadpoolctl.sh @@ -14,17 +14,3 @@ fi pip install --no-cache $THREADPOOLCTL --target $INSTALL_FOLDER cp $INSTALL_FOLDER/threadpoolctl.py _threadpoolctl.py rm -rf $INSTALL_FOLDER - -# Needed to rewrite the doctests -# Note: BSD sed -i needs an argument unders OSX -# so first renaming to .bak and then deleting backup files -#find loky -name "*.py" | xargs sed -i.bak "s/from loky/from joblib.externals.loky/" -#find loky -name "*.bak" | xargs rm - -#for f in $(git grep -l "cloudpickle" loky); do -# echo $f; -# sed -i 's/import cloudpickle/from joblib.externals import cloudpickle/' $f -# sed -i 's/from cloudpickle import/from joblib.externals.cloudpickle import/' $f -# done - -# sed -i 
"s/loky.backend.popen_loky/joblib.externals.loky.backend.popen_loky/" loky/backend/popen_loky_posix.py From 09f9423cbcd8f4954932a08ae210cf9f2b261faa Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 30 Dec 2019 11:41:33 +0100 Subject: [PATCH 095/163] merge master --- .circleci/config.yml | 14 +- .codecov.yml | 2 + .coveragerc | 1 + .github/workflows/twitter.yml | 25 + .gitignore | 160 ++- .travis.yml | 2 +- MANIFEST.in | 1 - Makefile | 4 +- README.rst | 12 +- azure-pipelines.yml | 72 +- benchmarks/bench_glmnet.py | 2 +- benchmarks/bench_lasso.py | 2 +- benchmarks/bench_multilabel_metrics.py | 2 +- benchmarks/bench_plot_fastkmeans.py | 2 +- benchmarks/bench_plot_hierarchical.py | 85 ++ benchmarks/bench_plot_lasso_path.py | 2 +- benchmarks/bench_plot_nmf.py | 10 +- benchmarks/bench_plot_omp_lars.py | 2 +- benchmarks/bench_plot_randomized_svd.py | 5 +- benchmarks/bench_plot_svd.py | 2 +- benchmarks/bench_rcv1_logreg_convergence.py | 2 +- benchmarks/bench_sgd_regression.py | 2 +- benchmarks/bench_sparsify.py | 2 +- benchmarks/bench_text_vectorizers.py | 2 +- benchmarks/bench_tsne_mnist.py | 5 +- build_tools/azure/install.cmd | 8 +- build_tools/azure/install.sh | 66 +- build_tools/azure/posix-32.yml | 5 +- build_tools/azure/posix.yml | 4 +- build_tools/azure/windows.yml | 4 +- build_tools/circle/build_doc.sh | 79 +- build_tools/circle/build_test_pypy.sh | 13 +- .../circle/check_deprecated_properties.sh | 16 - .../circle/{flake8_diff.sh => linting.sh} | 16 + build_tools/circle/list_versions.py | 6 +- build_tools/generate_authors_table.py | 48 +- build_tools/travis/install.sh | 156 +-- conftest.py | 22 +- doc/Makefile | 4 +- doc/about.rst | 438 ++++-- doc/authors.rst | 96 +- doc/authors_emeritus.rst | 2 +- doc/conf.py | 108 +- doc/conftest.py | 8 +- doc/contents.rst | 24 + doc/data_transforms.rst | 1 + doc/developers/advanced_installation.rst | 437 +++--- doc/developers/contributing.rst | 87 +- doc/developers/develop.rst | 2 +- doc/developers/index.rst | 
5 + doc/developers/maintainer.rst | 2 +- doc/developers/performance.rst | 9 +- doc/developers/plotting.rst | 90 ++ doc/developers/tips.rst | 12 +- doc/developers/utilities.rst | 8 +- doc/documentation.rst | 117 -- doc/faq.rst | 45 +- doc/getting_started.rst | 231 +++ doc/glossary.rst | 65 +- doc/governance.rst | 2 +- doc/images/anaconda-small.png | Bin 0 -> 11313 bytes doc/images/anaconda.png | Bin 0 -> 39373 bytes doc/images/axa-small.png | Bin 0 -> 11616 bytes doc/images/axa.png | Bin 0 -> 17847 bytes doc/images/bcg-small.png | Bin 0 -> 17039 bytes doc/images/bcg.png | Bin 0 -> 31049 bytes doc/images/bnp-small.png | Bin 0 -> 12497 bytes doc/images/bnp.png | Bin 0 -> 21156 bytes doc/images/cds-logo.png | Bin 6501 -> 13205 bytes doc/images/columbia-small.png | Bin 0 -> 1170 bytes doc/images/dataiku-small.png | Bin 0 -> 6101 bytes doc/images/dataiku.png | Bin 0 -> 9040 bytes doc/images/fnrs-logo-small.png | Bin 0 -> 1110 bytes doc/images/fujitsu-small.png | Bin 0 -> 6618 bytes doc/images/fujitsu.png | Bin 0 -> 18012 bytes doc/images/google-small.png | Bin 0 -> 4692 bytes doc/images/inria-logo.jpg | Bin 21107 -> 26245 bytes doc/images/inria-small.png | Bin 0 -> 7105 bytes doc/images/intel-small.png | Bin 0 -> 10935 bytes doc/images/intel.png | Bin 0 -> 7484 bytes doc/images/microsoft-small.png | Bin 0 -> 8047 bytes doc/images/microsoft.png | Bin 0 -> 10320 bytes doc/images/nvidia-small.png | Bin 0 -> 8070 bytes doc/images/nvidia.png | Bin 0 -> 10764 bytes doc/images/png-logo-inria-la-fondation.png | Bin 0 -> 6152 bytes doc/images/scikit-learn-logo-small.png | Bin 0 -> 5468 bytes doc/images/sloan_banner.png | Bin 22729 -> 29042 bytes doc/images/sloan_logo-small.png | Bin 0 -> 2236 bytes doc/images/sydney-stacked-small.png | Bin 0 -> 1728 bytes doc/images/telecom-small.png | Bin 0 -> 3779 bytes doc/includes/big_toc_css.rst | 44 +- doc/index.rst | 356 ----- doc/install.rst | 285 +++- doc/model_selection.rst | 1 + doc/modules/classes.rst | 194 +-- 
doc/modules/clustering.rst | 3 +- doc/modules/compose.rst | 34 +- doc/modules/computing.rst | 158 ++- doc/modules/cross_validation.rst | 37 +- doc/modules/decomposition.rst | 2 +- doc/modules/density.rst | 2 +- doc/modules/ensemble.rst | 176 ++- doc/modules/grid_search.rst | 17 + doc/modules/linear_model.rst | 51 +- doc/modules/model_evaluation.rst | 63 +- doc/modules/multiclass.rst | 197 ++- doc/modules/naive_bayes.rst | 34 + doc/modules/neighbors.rst | 109 +- doc/modules/partial_dependence.rst | 4 +- doc/modules/preprocessing.rst | 2 +- doc/modules/svm.rst | 9 +- doc/other_distributions.rst | 66 - doc/preface.rst | 10 +- doc/roadmap.rst | 102 +- doc/supervised_learning.rst | 1 + doc/templates/documentation.html | 14 + doc/templates/index.html | 255 ++++ doc/testimonials/testimonials.rst | 581 +++++--- .../scikit-learn-modern/javascript.html | 149 ++ doc/themes/scikit-learn-modern/layout.html | 130 ++ doc/themes/scikit-learn-modern/nav.html | 85 ++ doc/themes/scikit-learn-modern/search.html | 8 + .../scikit-learn-modern/static/css/theme.css | 1243 +++++++++++++++++ .../static/css/vendor/bootstrap.min.css | 6 + .../static/js/searchtools.js | 595 ++++++++ .../static/js/vendor/bootstrap.min.js | 6 + doc/themes/scikit-learn-modern/theme.conf | 8 + .../scikit-learn/static/img/digicosme.png | Bin 11400 -> 18585 bytes doc/themes/scikit-learn/static/nature.css_t | 1 + doc/tune_toc.rst | 79 +- doc/tutorial/index.rst | 5 +- doc/tutorial/machine_learning_map/index.rst | 61 +- doc/unsupervised_learning.rst | 1 + doc/user_guide.rst | 5 + doc/visualizations.rst | 7 + doc/whats_new.rst | 27 +- doc/whats_new/_contributors.rst | 14 +- doc/whats_new/changelog_legend.inc | 12 + doc/whats_new/v0.20.rst | 2 +- doc/whats_new/v0.21.rst | 6 +- doc/whats_new/v0.22.rst | 687 ++++++++- doc/whats_new/v0.23.rst | 129 ++ .../plot_model_complexity_influence.py | 6 +- .../applications/plot_prediction_latency.py | 10 +- .../plot_species_distribution_modeling.py | 30 +- 
.../bicluster/plot_bicluster_newsgroups.py | 4 +- .../bicluster/plot_spectral_biclustering.py | 11 +- .../bicluster/plot_spectral_coclustering.py | 10 +- .../plot_digits_classification.py | 39 +- examples/cluster/plot_affinity_propagation.py | 2 +- .../cluster/plot_birch_vs_minibatchkmeans.py | 2 +- examples/cluster/plot_dbscan.py | 2 +- examples/cluster/plot_mean_shift.py | 2 +- examples/cluster/plot_mini_batch_kmeans.py | 2 +- .../plot_ward_structured_vs_unstructured.py | 2 +- examples/compose/plot_column_transformer.py | 8 +- examples/compose/plot_compare_reduction.py | 8 +- examples/compose/plot_digits_pipe.py | 11 +- .../decomposition/plot_beta_divergence.py | 2 +- .../decomposition/plot_faces_decomposition.py | 2 +- examples/ensemble/plot_stack_predictors.py | 123 ++ .../plot_feature_selection_pipeline.py | 4 +- examples/gaussian_process/plot_gpc.py | 2 +- .../plot_gpr_on_structured_data.py | 174 +++ .../inspection/plot_partial_dependence.py | 31 +- .../inspection/plot_permutation_importance.py | 7 +- ...t_permutation_importance_multicollinear.py | 12 +- .../plot_lasso_dense_vs_sparse_data.py | 2 +- examples/linear_model/plot_logistic_path.py | 7 +- .../linear_model/plot_sgd_early_stopping.py | 2 +- examples/linear_model/plot_sgd_penalties.py | 4 +- .../plot_sgd_separating_hyperplane.py | 2 +- ...sparse_logistic_regression_20newsgroups.py | 7 +- .../plot_sparse_logistic_regression_mnist.py | 2 +- examples/manifold/plot_compare_methods.py | 96 +- examples/manifold/plot_swissroll.py | 2 +- examples/manifold/plot_t_sne_perplexity.py | 4 +- .../model_selection/plot_confusion_matrix.py | 78 +- .../model_selection/plot_precision_recall.py | 26 +- .../model_selection/plot_randomized_search.py | 37 +- examples/model_selection/plot_roc.py | 4 +- .../plot_classifier_chain_yeast.py | 2 +- .../approximate_nearest_neighbors.py | 294 ++++ .../plot_caching_nearest_neighbors.py | 64 + examples/neighbors/plot_nca_illustration.py | 59 +- examples/neighbors/plot_species_kde.py 
| 29 +- .../plot_changed_only_pprint_parameter.py | 2 +- ...ot_partial_dependence_visualization_api.py | 137 ++ examples/plot_roc_curve_visualization_api.py | 2 +- .../plot_discretization_classification.py | 2 +- .../preprocessing/plot_map_data_to_normal.py | 6 +- examples/release_highlights/README.txt | 6 + .../plot_release_highlights_0_22_0.py | 264 ++++ .../plot_label_propagation_digits.py | 4 +- ...abel_propagation_digits_active_learning.py | 4 +- .../plot_label_propagation_structure.py | 4 +- .../plot_label_propagation_versus_svm_iris.py | 10 +- examples/tree/plot_cost_complexity_pruning.py | 2 +- maint_tools/check_pxd_in_installation.py | 59 + maint_tools/test_docstrings.py | 220 +++ setup.py | 73 +- sklearn/__init__.py | 24 +- sklearn/_build_utils/__init__.py | 142 +- sklearn/_build_utils/deprecated_modules.py | 323 +++++ sklearn/_build_utils/openmp_helpers.py | 134 +- sklearn/_build_utils/pre_build_helpers.py | 70 + sklearn/_distributor_init.py | 10 + sklearn/base.py | 105 +- sklearn/calibration.py | 12 +- sklearn/cluster/__init__.py | 20 +- ...opagation_.py => _affinity_propagation.py} | 31 +- .../{hierarchical.py => _agglomerative.py} | 123 +- .../cluster/{bicluster.py => _bicluster.py} | 70 +- sklearn/cluster/{birch.py => _birch.py} | 65 +- sklearn/cluster/{dbscan_.py => _dbscan.py} | 130 +- sklearn/cluster/_feature_agglomeration.py | 4 +- ...ierarchical.pyx => _hierarchical_fast.pyx} | 94 +- sklearn/cluster/_k_means_elkan.pyx | 12 +- .../{_k_means.pxd => _k_means_fast.pxd} | 0 .../{_k_means.pyx => _k_means_fast.pyx} | 4 +- sklearn/cluster/_k_means_lloyd.pyx | 6 +- sklearn/cluster/{k_means_.py => _kmeans.py} | 20 +- .../{mean_shift_.py => _mean_shift.py} | 178 +-- sklearn/cluster/{optics_.py => _optics.py} | 25 +- sklearn/cluster/{spectral.py => _spectral.py} | 33 +- sklearn/cluster/setup.py | 8 +- .../tests/test_affinity_propagation.py | 18 +- sklearn/cluster/tests/test_bicluster.py | 18 +- sklearn/cluster/tests/test_birch.py | 12 +- 
sklearn/cluster/tests/test_dbscan.py | 23 +- .../tests/test_feature_agglomeration.py | 4 +- sklearn/cluster/tests/test_hierarchical.py | 68 +- sklearn/cluster/tests/test_k_means.py | 54 +- sklearn/cluster/tests/test_mean_shift.py | 21 +- sklearn/cluster/tests/test_optics.py | 15 +- sklearn/cluster/tests/test_spectral.py | 32 +- sklearn/compose/__init__.py | 4 +- sklearn/compose/_column_transformer.py | 88 +- sklearn/compose/_target.py | 6 +- .../compose/tests/test_column_transformer.py | 106 +- sklearn/compose/tests/test_target.py | 6 +- sklearn/covariance/__init__.py | 17 +- ...ptic_envelope.py => _elliptic_envelope.py} | 0 ...ovariance_.py => _empirical_covariance.py} | 8 +- .../{graph_lasso_.py => _graph_lasso.py} | 5 +- ...st_covariance.py => _robust_covariance.py} | 2 +- ...k_covariance_.py => _shrunk_covariance.py} | 8 +- sklearn/covariance/tests/test_covariance.py | 8 +- .../tests/test_elliptic_envelope.py | 6 +- .../covariance/tests/test_graphical_lasso.py | 6 +- .../tests/test_robust_covariance.py | 6 +- sklearn/cross_decomposition/__init__.py | 6 +- .../cross_decomposition/{cca_.py => _cca.py} | 2 +- .../cross_decomposition/{pls_.py => _pls.py} | 89 +- sklearn/cross_decomposition/tests/test_pls.py | 12 +- sklearn/datasets/__init__.py | 92 +- sklearn/datasets/{base.py => _base.py} | 39 +- ...rnia_housing.py => _california_housing.py} | 48 +- sklearn/datasets/{covtype.py => _covtype.py} | 10 +- .../datasets/{kddcup99.py => _kddcup99.py} | 8 +- sklearn/datasets/{lfw.py => _lfw.py} | 4 +- .../{olivetti_faces.py => _olivetti_faces.py} | 10 +- sklearn/datasets/{openml.py => _openml.py} | 8 +- sklearn/datasets/{rcv1.py => _rcv1.py} | 12 +- ...les_generator.py => _samples_generator.py} | 72 +- ...ributions.py => _species_distributions.py} | 10 +- ...t_format.pyx => _svmlight_format_fast.pyx} | 0 ...light_format.py => _svmlight_format_io.py} | 12 +- ...ty_newsgroups.py => _twenty_newsgroups.py} | 37 +- sklearn/datasets/setup.py | 4 +- 
sklearn/datasets/tests/test_20news.py | 19 +- sklearn/datasets/tests/test_base.py | 12 +- .../datasets/tests/test_california_housing.py | 49 +- sklearn/datasets/tests/test_covtype.py | 2 +- sklearn/datasets/tests/test_kddcup99.py | 2 +- sklearn/datasets/tests/test_lfw.py | 18 +- sklearn/datasets/tests/test_olivetti_faces.py | 4 +- sklearn/datasets/tests/test_openml.py | 38 +- sklearn/datasets/tests/test_rcv1.py | 6 +- .../datasets/tests/test_samples_generator.py | 72 +- .../datasets/tests/test_svmlight_format.py | 6 +- sklearn/decomposition/__init__.py | 40 +- sklearn/decomposition/{base.py => _base.py} | 0 .../{cdnmf_fast.pyx => _cdnmf_fast.pyx} | 0 .../{dict_learning.py => _dict_learning.py} | 52 +- ...factor_analysis.py => _factor_analysis.py} | 2 + .../{fastica_.py => _fastica.py} | 248 ++-- ...incremental_pca.py => _incremental_pca.py} | 6 +- .../{kernel_pca.py => _kernel_pca.py} | 7 +- .../decomposition/{online_lda.py => _lda.py} | 10 +- sklearn/decomposition/{nmf.py => _nmf.py} | 39 +- .../{_online_lda.pyx => _online_lda_fast.pyx} | 0 sklearn/decomposition/{pca.py => _pca.py} | 51 +- .../{sparse_pca.py => _sparse_pca.py} | 4 +- .../{truncated_svd.py => _truncated_svd.py} | 10 +- sklearn/decomposition/setup.py | 8 +- .../decomposition/tests/test_dict_learning.py | 8 +- .../tests/test_factor_analysis.py | 8 +- sklearn/decomposition/tests/test_fastica.py | 27 +- .../tests/test_incremental_pca.py | 6 +- .../decomposition/tests/test_kernel_pca.py | 26 +- sklearn/decomposition/tests/test_nmf.py | 18 +- .../decomposition/tests/test_online_lda.py | 15 +- sklearn/decomposition/tests/test_pca.py | 11 +- .../decomposition/tests/test_sparse_pca.py | 9 +- .../decomposition/tests/test_truncated_svd.py | 2 +- sklearn/discriminant_analysis.py | 45 +- sklearn/dummy.py | 136 +- sklearn/ensemble/__init__.py | 41 +- sklearn/ensemble/{bagging.py => _bagging.py} | 79 +- sklearn/ensemble/{base.py => _base.py} | 150 +- sklearn/ensemble/{forest.py => _forest.py} | 448 ++++-- 
.../ensemble/{gradient_boosting.py => _gb.py} | 1056 +------------- .../_hist_gradient_boosting/binning.py | 6 +- .../gradient_boosting.py | 82 +- .../ensemble/_hist_gradient_boosting/loss.py | 2 +- .../_hist_gradient_boosting/splitting.pyx | 4 +- .../tests/test_gradient_boosting.py | 13 +- .../tests/test_splitting.py | 2 +- .../tests/test_warm_start.py | 45 +- sklearn/ensemble/{iforest.py => _iforest.py} | 58 +- sklearn/ensemble/_stacking.py | 659 +++++++++ sklearn/ensemble/{voting.py => _voting.py} | 154 +- ...weight_boosting.py => _weight_boosting.py} | 208 +-- sklearn/ensemble/partial_dependence.py | 441 ------ sklearn/ensemble/tests/test_bagging.py | 15 +- sklearn/ensemble/tests/test_base.py | 4 +- sklearn/ensemble/tests/test_common.py | 172 +++ sklearn/ensemble/tests/test_forest.py | 113 +- .../ensemble/tests/test_gradient_boosting.py | 48 +- sklearn/ensemble/tests/test_iforest.py | 20 +- .../ensemble/tests/test_partial_dependence.py | 277 ---- sklearn/ensemble/tests/test_stacking.py | 479 +++++++ sklearn/ensemble/tests/test_voting.py | 138 +- .../ensemble/tests/test_weight_boosting.py | 110 +- sklearn/exceptions.py | 39 +- .../test_enable_hist_gradient_boosting.py | 2 +- .../tests/test_enable_iterative_imputer.py | 2 +- sklearn/externals/_arff.py | 2 +- sklearn/externals/_pep562.py | 58 + sklearn/externals/_threadpoolctl.py | 1085 ++++++++------ sklearn/externals/joblib/__init__.py | 2 +- sklearn/externals/joblib/numpy_pickle.py | 2 +- sklearn/externals/six.py | 583 -------- sklearn/externals/vendor_threadpoolctl.sh | 2 +- sklearn/feature_extraction/__init__.py | 4 +- ...dict_vectorizer.py => _dict_vectorizer.py} | 5 +- .../{hashing.py => _hash.py} | 10 +- .../{_hashing.pyx => _hashing_fast.pyx} | 7 +- .../{stop_words.py => _stop_words.py} | 0 sklearn/feature_extraction/image.py | 53 +- sklearn/feature_extraction/setup.py | 4 +- .../tests/test_feature_hasher.py | 29 +- .../feature_extraction/tests/test_image.py | 19 +- 
sklearn/feature_extraction/tests/test_text.py | 32 +- sklearn/feature_extraction/text.py | 235 +++- sklearn/feature_selection/__init__.py | 30 +- .../feature_selection/{base.py => _base.py} | 4 +- .../{from_model.py => _from_model.py} | 32 +- .../{mutual_info_.py => _mutual_info.py} | 2 +- sklearn/feature_selection/{rfe.py => _rfe.py} | 39 +- ..._selection.py => _univariate_selection.py} | 34 +- ...ce_threshold.py => _variance_threshold.py} | 20 +- sklearn/feature_selection/tests/test_base.py | 2 +- sklearn/feature_selection/tests/test_chi2.py | 6 +- .../tests/test_feature_select.py | 15 +- .../tests/test_from_model.py | 73 +- .../tests/test_mutual_info.py | 7 +- sklearn/feature_selection/tests/test_rfe.py | 30 +- .../tests/test_variance_threshold.py | 14 +- sklearn/gaussian_process/__init__.py | 4 +- sklearn/gaussian_process/{gpc.py => _gpc.py} | 87 +- sklearn/gaussian_process/{gpr.py => _gpr.py} | 46 +- sklearn/gaussian_process/kernels.py | 156 ++- .../tests/_mini_sequence_kernel.py | 51 + sklearn/gaussian_process/tests/test_gpc.py | 19 +- sklearn/gaussian_process/tests/test_gpr.py | 25 +- .../gaussian_process/tests/test_kernels.py | 54 +- sklearn/impute/_base.py | 116 +- sklearn/impute/_iterative.py | 121 +- sklearn/impute/_knn.py | 62 +- sklearn/impute/tests/test_base.py | 48 + sklearn/impute/tests/test_common.py | 86 ++ sklearn/impute/tests/test_impute.py | 109 +- sklearn/impute/tests/test_knn.py | 4 +- sklearn/inspection/__init__.py | 23 +- ...l_dependence.py => _partial_dependence.py} | 554 ++++++-- ...portance.py => _permutation_importance.py} | 61 +- .../tests/test_partial_dependence.py | 297 ++-- .../tests/test_permutation_importance.py | 200 +++ .../tests/test_plot_partial_dependence.py | 455 ++++++ sklearn/isotonic.py | 23 +- sklearn/kernel_approximation.py | 8 +- sklearn/kernel_ridge.py | 12 +- sklearn/linear_model/__init__.py | 49 +- sklearn/linear_model/{base.py => _base.py} | 95 +- sklearn/linear_model/{bayes.py => _bayes.py} | 16 +- 
.../{cd_fast.pyx => _cd_fast.pyx} | 0 ...nate_descent.py => _coordinate_descent.py} | 72 +- sklearn/linear_model/{huber.py => _huber.py} | 4 +- .../{least_angle.py => _least_angle.py} | 10 +- .../{logistic.py => _logistic.py} | 559 +++----- sklearn/linear_model/{omp.py => _omp.py} | 2 +- ...e_aggressive.py => _passive_aggressive.py} | 16 +- .../{perceptron.py => _perceptron.py} | 38 +- .../linear_model/{ransac.py => _ransac.py} | 27 +- sklearn/linear_model/{ridge.py => _ridge.py} | 387 ++--- sklearn/linear_model/{sag.py => _sag.py} | 4 +- .../{sag_fast.pyx.tp => _sag_fast.pyx.tp} | 8 +- .../{sgd_fast.pxd => _sgd_fast.pxd} | 0 .../{sgd_fast.pyx => _sgd_fast.pyx} | 10 +- ...sgd_fast_helpers.h => _sgd_fast_helpers.h} | 0 ...ic_gradient.py => _stochastic_gradient.py} | 196 +-- .../{theil_sen.py => _theil_sen.py} | 8 +- sklearn/linear_model/setup.py | 31 +- sklearn/linear_model/tests/test_base.py | 22 +- sklearn/linear_model/tests/test_bayes.py | 12 +- .../tests/test_coordinate_descent.py | 37 +- sklearn/linear_model/tests/test_huber.py | 10 +- .../linear_model/tests/test_least_angle.py | 14 +- sklearn/linear_model/tests/test_logistic.py | 121 +- sklearn/linear_model/tests/test_omp.py | 10 +- .../tests/test_passive_aggressive.py | 31 +- sklearn/linear_model/tests/test_perceptron.py | 8 +- sklearn/linear_model/tests/test_ransac.py | 32 +- sklearn/linear_model/tests/test_ridge.py | 84 +- sklearn/linear_model/tests/test_sag.py | 16 +- sklearn/linear_model/tests/test_sgd.py | 21 +- .../tests/test_sparse_coordinate_descent.py | 11 +- sklearn/linear_model/tests/test_theil_sen.py | 6 +- sklearn/manifold/__init__.py | 13 +- sklearn/manifold/_barnes_hut_tsne.pyx | 210 +-- sklearn/manifold/{isomap.py => _isomap.py} | 78 +- .../{locally_linear.py => _locally_linear.py} | 4 +- sklearn/manifold/{mds.py => _mds.py} | 0 ...l_embedding_.py => _spectral_embedding.py} | 40 +- sklearn/manifold/{t_sne.py => _t_sne.py} | 148 +- sklearn/manifold/_utils.pyx | 30 +- 
sklearn/manifold/tests/test_isomap.py | 54 +- sklearn/manifold/tests/test_locally_linear.py | 6 +- sklearn/manifold/tests/test_mds.py | 2 +- .../manifold/tests/test_spectral_embedding.py | 55 +- sklearn/manifold/tests/test_t_sne.py | 260 ++-- sklearn/metrics/__init__.py | 97 +- sklearn/metrics/{base.py => _base.py} | 6 +- .../{classification.py => _classification.py} | 369 ++--- sklearn/metrics/_pairwise_fast.pyx | 110 ++ sklearn/metrics/_plot/base.py | 40 + sklearn/metrics/_plot/confusion_matrix.py | 198 +++ .../metrics/_plot/precision_recall_curve.py | 168 +++ sklearn/metrics/_plot/roc_curve.py | 68 +- .../_plot/tests/test_plot_confusion_matrix.py | 247 ++++ .../_plot/tests/test_plot_precision_recall.py | 155 ++ .../_plot/tests/test_plot_roc_curve.py | 57 +- sklearn/metrics/{ranking.py => _ranking.py} | 139 +- .../metrics/{regression.py => _regression.py} | 96 +- sklearn/metrics/{scorer.py => _scorer.py} | 44 +- sklearn/metrics/cluster/__init__.py | 40 +- .../cluster/{bicluster.py => _bicluster.py} | 0 ...ast.pyx => _expected_mutual_info_fast.pyx} | 0 .../cluster/{supervised.py => _supervised.py} | 44 +- .../{unsupervised.py => _unsupervised.py} | 11 +- sklearn/metrics/cluster/setup.py | 4 +- .../metrics/cluster/tests/test_bicluster.py | 4 +- sklearn/metrics/cluster/tests/test_common.py | 40 +- .../metrics/cluster/tests/test_supervised.py | 8 +- .../cluster/tests/test_unsupervised.py | 14 +- sklearn/metrics/pairwise.py | 31 +- sklearn/metrics/pairwise_fast.pyx | 68 - sklearn/metrics/setup.py | 4 +- sklearn/metrics/tests/test_classification.py | 440 ++++-- sklearn/metrics/tests/test_common.py | 39 +- sklearn/metrics/tests/test_pairwise.py | 27 +- sklearn/metrics/tests/test_ranking.py | 63 +- sklearn/metrics/tests/test_regression.py | 11 +- sklearn/metrics/tests/test_score_objects.py | 83 +- sklearn/mixture/__init__.py | 4 +- sklearn/mixture/{base.py => _base.py} | 0 ...yesian_mixture.py => _bayesian_mixture.py} | 31 +- ...ussian_mixture.py => _gaussian_mixture.py} 
| 2 +- .../mixture/tests/test_bayesian_mixture.py | 14 +- .../mixture/tests/test_gaussian_mixture.py | 37 +- sklearn/model_selection/_search.py | 48 +- sklearn/model_selection/_split.py | 95 +- sklearn/model_selection/_validation.py | 41 +- sklearn/model_selection/tests/test_search.py | 127 +- sklearn/model_selection/tests/test_split.py | 39 +- .../model_selection/tests/test_validation.py | 53 +- sklearn/multiclass.py | 70 +- sklearn/multioutput.py | 54 +- sklearn/naive_bayes.py | 462 ++++-- sklearn/neighbors/__init__.py | 27 +- .../{ball_tree.pyx => _ball_tree.pyx} | 2 +- sklearn/neighbors/{base.py => _base.py} | 398 ++++-- .../{binary_tree.pxi => _binary_tree.pxi} | 29 +- .../{classification.py => _classification.py} | 60 +- .../{dist_metrics.pxd => _dist_metrics.pxd} | 4 +- .../{dist_metrics.pyx => _dist_metrics.pyx} | 4 +- sklearn/neighbors/_graph.py | 469 +++++++ .../neighbors/{kd_tree.pyx => _kd_tree.pyx} | 2 +- sklearn/neighbors/{kde.py => _kde.py} | 56 +- sklearn/neighbors/{lof.py => _lof.py} | 40 +- sklearn/neighbors/{nca.py => _nca.py} | 2 +- ...arest_centroid.py => _nearest_centroid.py} | 12 +- .../{quad_tree.pxd => _quad_tree.pxd} | 2 +- .../{quad_tree.pyx => _quad_tree.pyx} | 22 +- .../{regression.py => _regression.py} | 37 +- .../neighbors/{typedefs.pxd => _typedefs.pxd} | 0 .../neighbors/{typedefs.pyx => _typedefs.pyx} | 0 .../{unsupervised.py => _unsupervised.py} | 46 +- sklearn/neighbors/graph.py | 184 --- sklearn/neighbors/setup.py | 20 +- sklearn/neighbors/tests/test_ball_tree.py | 63 +- sklearn/neighbors/tests/test_dist_metrics.py | 4 +- sklearn/neighbors/tests/test_graph.py | 79 ++ sklearn/neighbors/tests/test_kd_tree.py | 67 +- sklearn/neighbors/tests/test_kde.py | 19 +- sklearn/neighbors/tests/test_lof.py | 8 +- sklearn/neighbors/tests/test_nca.py | 4 +- .../neighbors/tests/test_nearest_centroid.py | 2 +- sklearn/neighbors/tests/test_neighbors.py | 204 ++- .../tests/test_neighbors_pipeline.py | 221 +++ 
.../neighbors/tests/test_neighbors_tree.py | 96 ++ sklearn/neighbors/tests/test_quad_tree.py | 2 +- .../neural_network/_multilayer_perceptron.py | 136 +- sklearn/neural_network/_rbm.py | 46 +- .../neural_network/_stochastic_optimizers.py | 18 +- .../neural_network/multilayer_perceptron.py | 9 - sklearn/neural_network/rbm.py | 9 - sklearn/neural_network/tests/test_mlp.py | 3 +- sklearn/neural_network/tests/test_rbm.py | 2 +- .../tests/test_stochastic_optimizers.py | 2 +- sklearn/pipeline.py | 65 +- sklearn/preprocessing/__init__.py | 46 +- sklearn/preprocessing/{data.py => _data.py} | 147 +- sklearn/preprocessing/_discretization.py | 58 +- sklearn/preprocessing/_encoders.py | 128 +- .../preprocessing/_function_transformer.py | 11 +- sklearn/preprocessing/{label.py => _label.py} | 11 +- sklearn/preprocessing/tests/test_common.py | 4 +- sklearn/preprocessing/tests/test_data.py | 107 +- .../tests/test_discretization.py | 2 +- sklearn/preprocessing/tests/test_encoders.py | 48 +- .../tests/test_function_transformer.py | 4 +- sklearn/preprocessing/tests/test_label.py | 28 +- sklearn/random_projection.py | 37 +- sklearn/semi_supervised/__init__.py | 2 +- ...l_propagation.py => _label_propagation.py} | 11 +- .../tests/test_label_propagation.py | 45 +- sklearn/setup.py | 14 +- sklearn/svm/__init__.py | 10 +- sklearn/svm/{base.py => _base.py} | 51 +- sklearn/svm/{bounds.py => _bounds.py} | 2 +- sklearn/svm/{classes.py => _classes.py} | 147 +- sklearn/svm/{liblinear.pxd => _liblinear.pxi} | 5 +- sklearn/svm/{liblinear.pyx => _liblinear.pyx} | 2 + sklearn/svm/{libsvm.pxd => _libsvm.pxi} | 2 - sklearn/svm/{libsvm.pyx => _libsvm.pyx} | 309 ++-- .../{libsvm_sparse.pyx => _libsvm_sparse.pyx} | 0 sklearn/svm/setup.py | 12 +- sklearn/svm/src/liblinear/liblinear_helper.c | 6 +- sklearn/svm/src/liblinear/linear.cpp | 195 ++- sklearn/svm/src/liblinear/linear.h | 2 +- sklearn/svm/src/libsvm/svm.cpp | 40 +- sklearn/svm/src/libsvm/svm.h | 4 +- sklearn/svm/tests/test_bounds.py | 6 +- 
sklearn/svm/tests/test_sparse.py | 2 +- sklearn/svm/tests/test_svm.py | 290 +++- sklearn/tests/test_base.py | 35 +- sklearn/tests/test_build.py | 32 + sklearn/tests/test_calibration.py | 4 +- sklearn/tests/test_check_build.py | 2 +- sklearn/tests/test_common.py | 48 +- sklearn/tests/test_config.py | 2 +- sklearn/tests/test_discriminant_analysis.py | 33 +- sklearn/tests/test_docstring_parameters.py | 36 +- sklearn/tests/test_dummy.py | 29 +- sklearn/tests/test_import_deprecations.py | 22 +- sklearn/tests/test_isotonic.py | 27 +- sklearn/tests/test_kernel_approximation.py | 4 +- sklearn/tests/test_kernel_ridge.py | 4 +- sklearn/tests/test_metaestimators.py | 2 +- sklearn/tests/test_multiclass.py | 18 +- sklearn/tests/test_multioutput.py | 53 +- sklearn/tests/test_naive_bayes.py | 142 +- sklearn/tests/test_pipeline.py | 108 +- sklearn/tests/test_random_projection.py | 32 +- sklearn/tree/__init__.py | 14 +- sklearn/tree/{tree.py => _classes.py} | 376 ++--- sklearn/tree/{export.py => _export.py} | 23 +- sklearn/tree/_reingold_tilford.py | 19 +- sklearn/tree/_tree.pyx | 15 +- sklearn/tree/_utils.pxd | 2 +- sklearn/tree/setup.py | 4 - sklearn/tree/tests/test_export.py | 11 + sklearn/tree/tests/test_tree.py | 83 +- sklearn/utils/__init__.py | 207 ++- .../utils/{fast_dict.pxd => _fast_dict.pxd} | 0 .../utils/{fast_dict.pyx => _fast_dict.pyx} | 0 sklearn/utils/{mask.py => _mask.py} | 0 sklearn/utils/{mocking.py => _mocking.py} | 32 +- sklearn/utils/_openmp_helpers.pyx | 62 + ...seq_dataset.pxd.tp => _seq_dataset.pxd.tp} | 2 +- ...seq_dataset.pyx.tp => _seq_dataset.pyx.tp} | 2 +- sklearn/utils/_show_versions.py | 9 +- sklearn/utils/{testing.py => _testing.py} | 101 +- sklearn/utils/_unittest_backport.py | 224 --- .../{weight_vector.pxd => _weight_vector.pxd} | 9 - .../{weight_vector.pyx => _weight_vector.pyx} | 30 +- sklearn/utils/class_weight.py | 2 +- sklearn/utils/deprecation.py | 15 +- sklearn/utils/estimator_checks.py | 418 ++++-- sklearn/utils/fixes.py | 64 +- 
sklearn/utils/graph_shortest_path.pyx | 2 +- sklearn/utils/linear_assignment_.py | 5 +- sklearn/utils/metaestimators.py | 6 +- sklearn/utils/multiclass.py | 10 +- sklearn/utils/openmp_helpers.pyx | 41 - sklearn/utils/optimize.py | 26 +- sklearn/utils/random.py | 9 + sklearn/utils/setup.py | 40 +- sklearn/utils/sparsefuncs.py | 4 +- sklearn/utils/sparsefuncs_fast.pyx | 17 +- sklearn/utils/tests/test_class_weight.py | 8 +- sklearn/utils/tests/test_cython_blas.py | 2 +- sklearn/utils/tests/test_deprecated_utils.py | 116 +- sklearn/utils/tests/test_deprecation.py | 11 +- sklearn/utils/tests/test_estimator_checks.py | 45 +- sklearn/utils/tests/test_extmath.py | 20 +- sklearn/utils/tests/test_fast_dict.py | 2 +- sklearn/utils/tests/test_fixes.py | 29 +- sklearn/utils/tests/test_linear_assignment.py | 3 +- sklearn/utils/tests/test_multiclass.py | 38 +- sklearn/utils/tests/test_optimize.py | 6 +- sklearn/utils/tests/test_random.py | 20 +- sklearn/utils/tests/test_seq_dataset.py | 4 +- sklearn/utils/tests/test_show_versions.py | 2 +- sklearn/utils/tests/test_testing.py | 42 +- sklearn/utils/tests/test_utils.py | 117 +- sklearn/utils/tests/test_validation.py | 382 +++-- sklearn/utils/validation.py | 336 ++++- 641 files changed, 25985 insertions(+), 13001 deletions(-) create mode 100644 .github/workflows/twitter.yml create mode 100644 benchmarks/bench_plot_hierarchical.py delete mode 100755 build_tools/circle/check_deprecated_properties.sh rename build_tools/circle/{flake8_diff.sh => linting.sh} (90%) create mode 100644 doc/contents.rst create mode 100644 doc/developers/plotting.rst delete mode 100644 doc/documentation.rst create mode 100644 doc/getting_started.rst create mode 100644 doc/images/anaconda-small.png create mode 100644 doc/images/anaconda.png create mode 100644 doc/images/axa-small.png create mode 100644 doc/images/axa.png create mode 100644 doc/images/bcg-small.png create mode 100644 doc/images/bcg.png create mode 100644 doc/images/bnp-small.png create mode 
100644 doc/images/bnp.png create mode 100644 doc/images/columbia-small.png create mode 100644 doc/images/dataiku-small.png create mode 100644 doc/images/dataiku.png create mode 100644 doc/images/fnrs-logo-small.png create mode 100644 doc/images/fujitsu-small.png create mode 100644 doc/images/fujitsu.png create mode 100644 doc/images/google-small.png create mode 100644 doc/images/inria-small.png create mode 100644 doc/images/intel-small.png create mode 100644 doc/images/intel.png create mode 100644 doc/images/microsoft-small.png create mode 100644 doc/images/microsoft.png create mode 100644 doc/images/nvidia-small.png create mode 100644 doc/images/nvidia.png create mode 100644 doc/images/png-logo-inria-la-fondation.png create mode 100644 doc/images/scikit-learn-logo-small.png create mode 100644 doc/images/sloan_logo-small.png create mode 100644 doc/images/sydney-stacked-small.png create mode 100644 doc/images/telecom-small.png delete mode 100644 doc/index.rst delete mode 100644 doc/other_distributions.rst create mode 100644 doc/templates/documentation.html create mode 100644 doc/templates/index.html create mode 100644 doc/themes/scikit-learn-modern/javascript.html create mode 100644 doc/themes/scikit-learn-modern/layout.html create mode 100644 doc/themes/scikit-learn-modern/nav.html create mode 100644 doc/themes/scikit-learn-modern/search.html create mode 100644 doc/themes/scikit-learn-modern/static/css/theme.css create mode 100644 doc/themes/scikit-learn-modern/static/css/vendor/bootstrap.min.css create mode 100644 doc/themes/scikit-learn-modern/static/js/searchtools.js create mode 100644 doc/themes/scikit-learn-modern/static/js/vendor/bootstrap.min.js create mode 100644 doc/themes/scikit-learn-modern/theme.conf create mode 100644 doc/whats_new/changelog_legend.inc create mode 100644 doc/whats_new/v0.23.rst create mode 100644 examples/ensemble/plot_stack_predictors.py create mode 100644 examples/gaussian_process/plot_gpr_on_structured_data.py create mode 100644 
examples/neighbors/approximate_nearest_neighbors.py create mode 100644 examples/neighbors/plot_caching_nearest_neighbors.py create mode 100644 examples/plot_partial_dependence_visualization_api.py create mode 100644 examples/release_highlights/README.txt create mode 100644 examples/release_highlights/plot_release_highlights_0_22_0.py create mode 100644 maint_tools/check_pxd_in_installation.py create mode 100644 maint_tools/test_docstrings.py create mode 100644 sklearn/_build_utils/deprecated_modules.py create mode 100644 sklearn/_build_utils/pre_build_helpers.py create mode 100644 sklearn/_distributor_init.py rename sklearn/cluster/{affinity_propagation_.py => _affinity_propagation.py} (94%) rename sklearn/cluster/{hierarchical.py => _agglomerative.py} (91%) rename sklearn/cluster/{bicluster.py => _bicluster.py} (91%) rename sklearn/cluster/{birch.py => _birch.py} (97%) rename sklearn/cluster/{dbscan_.py => _dbscan.py} (81%) rename sklearn/cluster/{_hierarchical.pyx => _hierarchical_fast.pyx} (83%) rename sklearn/cluster/{_k_means.pxd => _k_means_fast.pxd} (100%) rename sklearn/cluster/{_k_means.pyx => _k_means_fast.pyx} (100%) rename sklearn/cluster/{k_means_.py => _kmeans.py} (99%) rename sklearn/cluster/{mean_shift_.py => _mean_shift.py} (74%) rename sklearn/cluster/{optics_.py => _optics.py} (98%) rename sklearn/cluster/{spectral.py => _spectral.py} (93%) rename sklearn/covariance/{elliptic_envelope.py => _elliptic_envelope.py} (100%) rename sklearn/covariance/{empirical_covariance_.py => _empirical_covariance.py} (97%) rename sklearn/covariance/{graph_lasso_.py => _graph_lasso.py} (99%) rename sklearn/covariance/{robust_covariance.py => _robust_covariance.py} (99%) rename sklearn/covariance/{shrunk_covariance_.py => _shrunk_covariance.py} (98%) rename sklearn/cross_decomposition/{cca_.py => _cca.py} (99%) rename sklearn/cross_decomposition/{pls_.py => _pls.py} (92%) rename sklearn/datasets/{base.py => _base.py} (96%) rename 
sklearn/datasets/{california_housing.py => _california_housing.py} (77%) rename sklearn/datasets/{covtype.py => _covtype.py} (96%) rename sklearn/datasets/{kddcup99.py => _kddcup99.py} (98%) rename sklearn/datasets/{lfw.py => _lfw.py} (99%) rename sklearn/datasets/{olivetti_faces.py => _olivetti_faces.py} (96%) rename sklearn/datasets/{openml.py => _openml.py} (99%) rename sklearn/datasets/{rcv1.py => _rcv1.py} (98%) rename sklearn/datasets/{samples_generator.py => _samples_generator.py} (96%) rename sklearn/datasets/{species_distributions.py => _species_distributions.py} (98%) rename sklearn/datasets/{_svmlight_format.pyx => _svmlight_format_fast.pyx} (100%) rename sklearn/datasets/{svmlight_format.py => _svmlight_format_io.py} (98%) rename sklearn/datasets/{twenty_newsgroups.py => _twenty_newsgroups.py} (94%) rename sklearn/decomposition/{base.py => _base.py} (100%) rename sklearn/decomposition/{cdnmf_fast.pyx => _cdnmf_fast.pyx} (100%) rename sklearn/decomposition/{dict_learning.py => _dict_learning.py} (97%) rename sklearn/decomposition/{factor_analysis.py => _factor_analysis.py} (99%) rename sklearn/decomposition/{fastica_.py => _fastica.py} (74%) rename sklearn/decomposition/{incremental_pca.py => _incremental_pca.py} (99%) rename sklearn/decomposition/{kernel_pca.py => _kernel_pca.py} (97%) rename sklearn/decomposition/{online_lda.py => _lda.py} (98%) rename sklearn/decomposition/{nmf.py => _nmf.py} (98%) rename sklearn/decomposition/{_online_lda.pyx => _online_lda_fast.pyx} (100%) rename sklearn/decomposition/{pca.py => _pca.py} (96%) rename sklearn/decomposition/{sparse_pca.py => _sparse_pca.py} (99%) rename sklearn/decomposition/{truncated_svd.py => _truncated_svd.py} (96%) rename sklearn/ensemble/{bagging.py => _bagging.py} (93%) rename sklearn/ensemble/{base.py => _base.py} (51%) rename sklearn/ensemble/{forest.py => _forest.py} (86%) rename sklearn/ensemble/{gradient_boosting.py => _gb.py} (66%) rename sklearn/ensemble/{iforest.py => _iforest.py} (92%) 
create mode 100644 sklearn/ensemble/_stacking.py rename sklearn/ensemble/{voting.py => _voting.py} (78%) rename sklearn/ensemble/{weight_boosting.py => _weight_boosting.py} (88%) delete mode 100644 sklearn/ensemble/partial_dependence.py create mode 100644 sklearn/ensemble/tests/test_common.py delete mode 100644 sklearn/ensemble/tests/test_partial_dependence.py create mode 100644 sklearn/ensemble/tests/test_stacking.py create mode 100644 sklearn/externals/_pep562.py delete mode 100644 sklearn/externals/six.py rename sklearn/feature_extraction/{dict_vectorizer.py => _dict_vectorizer.py} (98%) rename sklearn/feature_extraction/{hashing.py => _hash.py} (95%) rename sklearn/feature_extraction/{_hashing.pyx => _hashing_fast.pyx} (93%) rename sklearn/feature_extraction/{stop_words.py => _stop_words.py} (100%) rename sklearn/feature_selection/{base.py => _base.py} (96%) rename sklearn/feature_selection/{from_model.py => _from_model.py} (89%) rename sklearn/feature_selection/{mutual_info_.py => _mutual_info.py} (99%) rename sklearn/feature_selection/{rfe.py => _rfe.py} (94%) rename sklearn/feature_selection/{univariate_selection.py => _univariate_selection.py} (97%) rename sklearn/feature_selection/{variance_threshold.py => _variance_threshold.py} (83%) rename sklearn/gaussian_process/{gpc.py => _gpc.py} (91%) rename sklearn/gaussian_process/{gpr.py => _gpr.py} (92%) create mode 100644 sklearn/gaussian_process/tests/_mini_sequence_kernel.py create mode 100644 sklearn/impute/tests/test_base.py create mode 100644 sklearn/impute/tests/test_common.py rename sklearn/inspection/{partial_dependence.py => _partial_dependence.py} (57%) rename sklearn/inspection/{permutation_importance.py => _permutation_importance.py} (70%) create mode 100644 sklearn/inspection/tests/test_plot_partial_dependence.py rename sklearn/linear_model/{base.py => _base.py} (88%) rename sklearn/linear_model/{bayes.py => _bayes.py} (98%) rename sklearn/linear_model/{cd_fast.pyx => _cd_fast.pyx} (100%) rename 
sklearn/linear_model/{coordinate_descent.py => _coordinate_descent.py} (98%) rename sklearn/linear_model/{huber.py => _huber.py} (99%) rename sklearn/linear_model/{least_angle.py => _least_angle.py} (99%) rename sklearn/linear_model/{logistic.py => _logistic.py} (82%) rename sklearn/linear_model/{omp.py => _omp.py} (99%) rename sklearn/linear_model/{passive_aggressive.py => _passive_aggressive.py} (97%) rename sklearn/linear_model/{perceptron.py => _perceptron.py} (83%) rename sklearn/linear_model/{ransac.py => _ransac.py} (96%) rename sklearn/linear_model/{ridge.py => _ridge.py} (87%) rename sklearn/linear_model/{sag.py => _sag.py} (99%) rename sklearn/linear_model/{sag_fast.pyx.tp => _sag_fast.pyx.tp} (99%) rename sklearn/linear_model/{sgd_fast.pxd => _sgd_fast.pxd} (100%) rename sklearn/linear_model/{sgd_fast.pyx => _sgd_fast.pyx} (99%) rename sklearn/linear_model/{sgd_fast_helpers.h => _sgd_fast_helpers.h} (100%) rename sklearn/linear_model/{stochastic_gradient.py => _stochastic_gradient.py} (93%) rename sklearn/linear_model/{theil_sen.py => _theil_sen.py} (98%) rename sklearn/manifold/{isomap.py => _isomap.py} (73%) rename sklearn/manifold/{locally_linear.py => _locally_linear.py} (99%) rename sklearn/manifold/{mds.py => _mds.py} (100%) rename sklearn/manifold/{spectral_embedding_.py => _spectral_embedding.py} (92%) rename sklearn/manifold/{t_sne.py => _t_sne.py} (89%) rename sklearn/metrics/{base.py => _base.py} (97%) rename sklearn/metrics/{classification.py => _classification.py} (89%) create mode 100644 sklearn/metrics/_pairwise_fast.pyx create mode 100644 sklearn/metrics/_plot/base.py create mode 100644 sklearn/metrics/_plot/confusion_matrix.py create mode 100644 sklearn/metrics/_plot/precision_recall_curve.py create mode 100644 sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py create mode 100644 sklearn/metrics/_plot/tests/test_plot_precision_recall.py rename sklearn/metrics/{ranking.py => _ranking.py} (90%) rename sklearn/metrics/{regression.py 
=> _regression.py} (89%) rename sklearn/metrics/{scorer.py => _scorer.py} (94%) rename sklearn/metrics/cluster/{bicluster.py => _bicluster.py} (100%) rename sklearn/metrics/cluster/{expected_mutual_info_fast.pyx => _expected_mutual_info_fast.pyx} (100%) rename sklearn/metrics/cluster/{supervised.py => _supervised.py} (96%) rename sklearn/metrics/cluster/{unsupervised.py => _unsupervised.py} (97%) delete mode 100644 sklearn/metrics/pairwise_fast.pyx rename sklearn/mixture/{base.py => _base.py} (100%) rename sklearn/mixture/{bayesian_mixture.py => _bayesian_mixture.py} (97%) rename sklearn/mixture/{gaussian_mixture.py => _gaussian_mixture.py} (99%) rename sklearn/neighbors/{ball_tree.pyx => _ball_tree.pyx} (99%) rename sklearn/neighbors/{base.py => _base.py} (72%) rename sklearn/neighbors/{binary_tree.pxi => _binary_tree.pxi} (99%) rename sklearn/neighbors/{classification.py => _classification.py} (91%) rename sklearn/neighbors/{dist_metrics.pxd => _dist_metrics.pxd} (96%) rename sklearn/neighbors/{dist_metrics.pyx => _dist_metrics.pyx} (99%) create mode 100644 sklearn/neighbors/_graph.py rename sklearn/neighbors/{kd_tree.pyx => _kd_tree.pyx} (99%) rename sklearn/neighbors/{kde.py => _kde.py} (86%) rename sklearn/neighbors/{lof.py => _lof.py} (94%) rename sklearn/neighbors/{nca.py => _nca.py} (99%) rename sklearn/neighbors/{nearest_centroid.py => _nearest_centroid.py} (95%) rename sklearn/neighbors/{quad_tree.pxd => _quad_tree.pxd} (99%) rename sklearn/neighbors/{quad_tree.pyx => _quad_tree.pyx} (98%) rename sklearn/neighbors/{regression.py => _regression.py} (91%) rename sklearn/neighbors/{typedefs.pxd => _typedefs.pxd} (100%) rename sklearn/neighbors/{typedefs.pyx => _typedefs.pyx} (100%) rename sklearn/neighbors/{unsupervised.py => _unsupervised.py} (72%) delete mode 100644 sklearn/neighbors/graph.py create mode 100644 sklearn/neighbors/tests/test_graph.py create mode 100644 sklearn/neighbors/tests/test_neighbors_pipeline.py create mode 100644 
sklearn/neighbors/tests/test_neighbors_tree.py delete mode 100644 sklearn/neural_network/multilayer_perceptron.py delete mode 100644 sklearn/neural_network/rbm.py rename sklearn/preprocessing/{data.py => _data.py} (97%) rename sklearn/preprocessing/{label.py => _label.py} (98%) rename sklearn/semi_supervised/{label_propagation.py => _label_propagation.py} (98%) rename sklearn/svm/{base.py => _base.py} (96%) rename sklearn/svm/{bounds.py => _bounds.py} (97%) rename sklearn/svm/{classes.py => _classes.py} (95%) rename sklearn/svm/{liblinear.pxd => _liblinear.pxi} (98%) rename sklearn/svm/{liblinear.pyx => _liblinear.pyx} (99%) rename sklearn/svm/{libsvm.pxd => _libsvm.pxi} (99%) rename sklearn/svm/{libsvm.pyx => _libsvm.pyx} (69%) rename sklearn/svm/{libsvm_sparse.pyx => _libsvm_sparse.pyx} (100%) create mode 100644 sklearn/tests/test_build.py rename sklearn/tree/{tree.py => _classes.py} (85%) rename sklearn/tree/{export.py => _export.py} (98%) rename sklearn/utils/{fast_dict.pxd => _fast_dict.pxd} (100%) rename sklearn/utils/{fast_dict.pyx => _fast_dict.pyx} (100%) rename sklearn/utils/{mask.py => _mask.py} (100%) rename sklearn/utils/{mocking.py => _mocking.py} (82%) create mode 100644 sklearn/utils/_openmp_helpers.pyx rename sklearn/utils/{seq_dataset.pxd.tp => _seq_dataset.pxd.tp} (98%) rename sklearn/utils/{seq_dataset.pyx.tp => _seq_dataset.pyx.tp} (99%) rename sklearn/utils/{testing.py => _testing.py} (91%) delete mode 100644 sklearn/utils/_unittest_backport.py rename sklearn/utils/{weight_vector.pxd => _weight_vector.pxd} (84%) rename sklearn/utils/{weight_vector.pyx => _weight_vector.pyx} (87%) delete mode 100644 sklearn/utils/openmp_helpers.pyx diff --git a/.circleci/config.yml b/.circleci/config.yml index de08f2d5622f5..9fecc150ba297 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -13,6 +13,11 @@ jobs: - NUMPY_VERSION: 1.11.0 - SCIPY_VERSION: 0.17.0 - MATPLOTLIB_VERSION: 1.5.1 + # on conda, this is the latest for python 3.5 + # The 
following places need to be in sync with regard to Cython version: + # - .circleci config file + # - sklearn/_build_utils/__init__.py + # - advanced installation guide - CYTHON_VERSION: 0.28.5 - SCIKIT_IMAGE_VERSION: 0.12.3 steps: @@ -91,15 +96,12 @@ jobs: name: dependencies command: sudo pip install flake8 - run: - name: flake8 - command: ./build_tools/circle/flake8_diff.sh - - run: - name: deprecated_properties_checks - command: ./build_tools/circle/check_deprecated_properties.sh + name: linting + command: ./build_tools/circle/linting.sh pypy3: docker: - - image: pypy:3.6-7.1.1 + - image: pypy:3.6-7.2.0 steps: - restore_cache: keys: diff --git a/.codecov.yml b/.codecov.yml index 6f7f65294ba32..07ab69f251592 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -22,3 +22,5 @@ coverage: ignore: - "sklearn/externals" +- "sklearn/_build_utils" +- "**/setup.py" diff --git a/.coveragerc b/.coveragerc index 7f1b3b706cace..a8601458a0b07 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,5 +4,6 @@ source = sklearn parallel = True omit = */sklearn/externals/* + */sklearn/_build_utils/* */benchmarks/* **/setup.py diff --git a/.github/workflows/twitter.yml b/.github/workflows/twitter.yml new file mode 100644 index 0000000000000..d0b41e1c684a0 --- /dev/null +++ b/.github/workflows/twitter.yml @@ -0,0 +1,25 @@ +# Tweet the URL of a commit on @sklearn_commits whenever a push event +# happens on the master branch +name: Twitter Push Notification + + +on: + push: + branches: + - master + + +jobs: + tweet: + name: Twitter Notification + runs-on: ubuntu-latest + steps: + - name: Tweet URL of last commit as @sklearn_commits + uses: xorilog/twitter-action@0.1 + with: + args: "-message \"https://github.com/scikit-learn/scikit-learn/commit/${{ github.sha }}\"" + env: + TWITTER_CONSUMER_KEY: ${{ secrets.TWITTER_CONSUMER_KEY }} + TWITTER_CONSUMER_SECRET: ${{ secrets.TWITTER_CONSUMER_SECRET }} + TWITTER_ACCESS_TOKEN: ${{ secrets.TWITTER_ACCESS_TOKEN }} + TWITTER_ACCESS_SECRET: ${{ 
secrets.TWITTER_ACCESS_SECRET }} diff --git a/.gitignore b/.gitignore index 20483c452cd61..9b158da07a2ec 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,160 @@ _configtest.o.d .mypy_cache/ # files generated from a template -sklearn/utils/seq_dataset.pyx -sklearn/utils/seq_dataset.pxd -sklearn/linear_model/sag_fast.pyx +sklearn/utils/_seq_dataset.pyx +sklearn/utils/_seq_dataset.pxd +sklearn/linear_model/_sag_fast.pyx + +# deprecated paths +# TODO: Remove in 0.24 +# All of these files should have a match in _build_utils/deprecated_modules.py +sklearn/utils/mocking.py + +sklearn/ensemble/bagging.py +sklearn/ensemble/base.py +sklearn/ensemble/forest.py +sklearn/ensemble/gradient_boosting.py +sklearn/ensemble/iforest.py +sklearn/ensemble/stacking.py +sklearn/ensemble/voting.py +sklearn/ensemble/weight_boosting.py +sklearn/tree/export.py +sklearn/tree/tree.py + +sklearn/neural_network/rbm.py +sklearn/neural_network/multilayer_perceptron.py + +sklearn/utils/weight_vector.py +sklearn/utils/seq_dataset.py +sklearn/utils/fast_dict.py +sklearn/utils/testing.py + +sklearn/cluster/affinity_propagation_.py +sklearn/cluster/bicluster.py +sklearn/cluster/birch.py +sklearn/cluster/dbscan_.py +sklearn/cluster/hierarchical.py +sklearn/cluster/k_means_.py +sklearn/cluster/mean_shift_.py +sklearn/cluster/optics_.py +sklearn/cluster/spectral.py + +sklearn/mixture/base.py +sklearn/mixture/gaussian_mixture.py +sklearn/mixture/bayesian_mixture.py + +sklearn/covariance/elliptic_envelope.py +sklearn/covariance/empirical_covariance_.py +sklearn/covariance/graph_lasso_.py +sklearn/covariance/robust_covariance.py +sklearn/covariance/shrunk_covariance_.py + +sklearn/cross_decomposition/cca_.py +sklearn/cross_decomposition/pls_.py + +sklearn/svm/base.py +sklearn/svm/classes.py +sklearn/svm/bounds.py +sklearn/svm/libsvm.py +sklearn/svm/libsvm_sparse.py +sklearn/svm/liblinear.py + +sklearn/decomposition/base.py +sklearn/decomposition/dict_learning.py +sklearn/decomposition/cdnmf_fast.py 
+sklearn/decomposition/factor_analysis.py +sklearn/decomposition/fastica_.py +sklearn/decomposition/incremental_pca.py +sklearn/decomposition/kernel_pca.py +sklearn/decomposition/nmf.py +sklearn/decomposition/online_lda.py +sklearn/decomposition/online_lda_fast.py +sklearn/decomposition/pca.py +sklearn/decomposition/sparse_pca.py +sklearn/decomposition/truncated_svd.py + +sklearn/gaussian_process/gpr.py +sklearn/gaussian_process/gpc.py + +sklearn/datasets/base.py +sklearn/datasets/california_housing.py +sklearn/datasets/covtype.py +sklearn/datasets/kddcup99.py +sklearn/datasets/lfw.py +sklearn/datasets/olivetti_faces.py +sklearn/datasets/openml.py +sklearn/datasets/rcv1.py +sklearn/datasets/samples_generator.py +sklearn/datasets/species_distributions.py +sklearn/datasets/svmlight_format.py +sklearn/datasets/twenty_newsgroups.py + +sklearn/feature_extraction/dict_vectorizer.py +sklearn/feature_extraction/hashing.py +sklearn/feature_extraction/stop_words.py + +sklearn/linear_model/base.py +sklearn/linear_model/bayes.py +sklearn/linear_model/cd_fast.py +sklearn/linear_model/coordinate_descent.py +sklearn/linear_model/huber.py +sklearn/linear_model/least_angle.py +sklearn/linear_model/logistic.py +sklearn/linear_model/omp.py +sklearn/linear_model/passive_aggressive.py +sklearn/linear_model/perceptron.py +sklearn/linear_model/ransac.py +sklearn/linear_model/ridge.py +sklearn/linear_model/sag.py +sklearn/linear_model/sag_fast.py +sklearn/linear_model/sgd_fast.py +sklearn/linear_model/stochastic_gradient.py +sklearn/linear_model/theil_sen.py + +sklearn/metrics/cluster/bicluster.py +sklearn/metrics/cluster/supervised.py +sklearn/metrics/cluster/unsupervised.py +sklearn/metrics/cluster/expected_mutual_info_fast.py + +sklearn/metrics/base.py +sklearn/metrics/classification.py +sklearn/metrics/regression.py +sklearn/metrics/ranking.py +sklearn/metrics/pairwise_fast.py +sklearn/metrics/scorer.py + +sklearn/inspection/partial_dependence.py 
+sklearn/inspection/permutation_importance.py + +sklearn/neighbors/ball_tree.py +sklearn/neighbors/base.py +sklearn/neighbors/classification.py +sklearn/neighbors/dist_metrics.py +sklearn/neighbors/graph.py +sklearn/neighbors/kd_tree.py +sklearn/neighbors/kde.py +sklearn/neighbors/lof.py +sklearn/neighbors/nca.py +sklearn/neighbors/nearest_centroid.py +sklearn/neighbors/quad_tree.py +sklearn/neighbors/regression.py +sklearn/neighbors/typedefs.py +sklearn/neighbors/unsupervised.py + +sklearn/manifold/isomap.py +sklearn/manifold/locally_linear.py +sklearn/manifold/mds.py +sklearn/manifold/spectral_embedding_.py +sklearn/manifold/t_sne.py + +sklearn/semi_supervised/label_propagation.py + +sklearn/preprocessing/data.py +sklearn/preprocessing/label.py + +sklearn/feature_selection/base.py +sklearn/feature_selection/from_model.py +sklearn/feature_selection/mutual_info.py +sklearn/feature_selection/rfe.py +sklearn/feature_selection/univariate_selection.py +sklearn/feature_selection/variance_threshold.py diff --git a/.travis.yml b/.travis.yml index 6dff0237ba60c..9fda90f71a7c0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -22,7 +22,7 @@ matrix: # installed from their CI wheels in a virtualenv with the Python # interpreter provided by travis. 
- python: 3.7 - env: DISTRIB="scipy-dev" CHECK_WARNINGS="true" + env: CHECK_WARNINGS="true" if: type = cron OR commit_message =~ /\[scipy-dev\]/ install: source build_tools/travis/install.sh diff --git a/MANIFEST.in b/MANIFEST.in index e36adcae38b0e..04d62596bbf3d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,4 +5,3 @@ recursive-include sklearn *.c *.h *.pyx *.pxd *.pxi *.tp recursive-include sklearn/datasets *.csv *.csv.gz *.rst *.jpg *.txt *.arff.gz *.json.gz include COPYING include README.rst - diff --git a/Makefile b/Makefile index 3980d8cfc2281..43fc5afe63361 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,8 @@ clean-ctags: clean: clean-ctags $(PYTHON) setup.py clean rm -rf dist + # TODO: Remove when all deprecated modules are removed. + $(PYTHON) sklearn/_build_utils/deprecated_modules.py in: inplace # just a shortcut inplace: @@ -65,4 +67,4 @@ code-analysis: pylint -E -i y sklearn/ -d E1103,E0611,E1101 flake8-diff: - ./build_tools/circle/flake8_diff.sh + ./build_tools/circle/linting.sh diff --git a/README.rst b/README.rst index 12dccbecd6802..41197e178904a 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,6 @@ .. -*- mode: rst -*- -|Azure|_ |Travis|_ |Codecov|_ |CircleCI|_ |Python35|_ |PyPi|_ |DOI|_ +|Azure|_ |Travis|_ |Codecov|_ |CircleCI|_ |PythonVersion|_ |PyPi|_ |DOI|_ .. |Azure| image:: https://dev.azure.com/scikit-learn/scikit-learn/_apis/build/status/scikit-learn.scikit-learn?branchName=master .. _Azure: https://dev.azure.com/scikit-learn/scikit-learn/_build/latest?definitionId=1&branchName=master @@ -14,8 +14,8 @@ .. |CircleCI| image:: https://circleci.com/gh/scikit-learn/scikit-learn/tree/master.svg?style=shield&circle-token=:circle-token .. _CircleCI: https://circleci.com/gh/scikit-learn/scikit-learn -.. |Python35| image:: https://img.shields.io/badge/python-3.5-blue.svg -.. _Python35: https://badge.fury.io/py/scikit-learn +.. |PythonVersion| image:: https://img.shields.io/pypi/pyversions/scikit-learn.svg +..
_PythonVersion: https://img.shields.io/pypi/pyversions/scikit-learn.svg .. |PyPi| image:: https://badge.fury.io/py/scikit-learn.svg .. _PyPi: https://badge.fury.io/py/scikit-learn @@ -31,7 +31,7 @@ SciPy and is distributed under the 3-Clause BSD license. The project was started in 2007 by David Cournapeau as a Google Summer of Code project, and since then many volunteers have contributed. See -the `About us `_ page +the `About us `__ page for a list of core contributors. It is currently maintained by a team of volunteers. @@ -55,7 +55,7 @@ scikit-learn requires: **Scikit-learn 0.20 was the last version to support Python 2.7 and Python 3.4.** scikit-learn 0.21 and later require Python 3.5 or newer. -Scikit-learn plotting capabilities (i.e., functions start with "plot_" +Scikit-learn plotting capabilities (i.e., functions start with ``plot_`` and classes end with "Display") require Matplotlib (>= 1.5.1). For running the examples Matplotlib >= 1.5.1 is required. A few examples require scikit-image >= 0.12.3, a few examples require pandas >= 0.18.0. @@ -138,7 +138,7 @@ Project History The project was started in 2007 by David Cournapeau as a Google Summer of Code project, and since then many volunteers have contributed. See -the `About us `_ page +the `About us `__ page for a list of core contributors. The project is currently maintained by a team of volunteers. 
diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 56fb99974ae52..e2ff71802ce72 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -1,9 +1,51 @@ # Adapted from https://github.com/pandas-dev/pandas/blob/master/azure-pipelines.yml jobs: +- job: linting + displayName: Linting + pool: + vmImage: ubuntu-16.04 + steps: + - bash: echo "##vso[task.prependpath]$CONDA/bin" + displayName: Add conda to PATH + - bash: sudo chown -R $USER $CONDA + displayName: Take ownership of conda installation + - bash: conda create --name flake8_env --yes flake8 + displayName: Install flake8 + - bash: | + if [[ $BUILD_SOURCEVERSIONMESSAGE =~ \[lint\ skip\] ]]; then + # skip linting + echo "Skipping linting" + exit 0 + else + source activate flake8_env + ./build_tools/circle/linting.sh + fi + displayName: Run linting + + +# Will run all the time regardless of linting outcome. +- template: build_tools/azure/posix.yml + parameters: + name: Linux_Runs + vmImage: ubuntu-16.04 + matrix: + pylatest_conda_mkl: + DISTRIB: 'conda' + PYTHON_VERSION: '*' + BLAS: 'mkl' + NUMPY_VERSION: '*' + SCIPY_VERSION: '*' + CYTHON_VERSION: '*' + PILLOW_VERSION: '*' + PYTEST_VERSION: '*' + JOBLIB_VERSION: '*' + COVERAGE: 'true' + - template: build_tools/azure/posix.yml parameters: name: Linux vmImage: ubuntu-16.04 + dependsOn: [linting] matrix: # Linux environment to test that scikit-learn can be built against # versions of numpy, scipy with ATLAS that comes with Ubuntu Xenial 16.04 @@ -12,17 +54,17 @@ jobs: DISTRIB: 'ubuntu' PYTHON_VERSION: '3.5' JOBLIB_VERSION: '0.11' - SKLEARN_NO_OPENMP: 'True' # Linux + Python 3.5 build with OpenBLAS and without SITE_JOBLIB py35_conda_openblas: DISTRIB: 'conda' PYTHON_VERSION: '3.5' - INSTALL_MKL: 'false' + BLAS: 'openblas' NUMPY_VERSION: '1.11.0' SCIPY_VERSION: '0.17.0' PANDAS_VERSION: '*' CYTHON_VERSION: '*' - PYTEST_VERSION: '*' + # temporary pin pytest due to unknown failure with pytest 5.3 + PYTEST_VERSION: '5.2' PILLOW_VERSION: '4.0.0' 
MATPLOTLIB_VERSION: '1.5.1' # later version of joblib are not packaged in conda for Python 3.5 @@ -31,8 +73,9 @@ jobs: # Linux environment to test the latest available dependencies and MKL. # It runs tests requiring pandas and PyAMG. pylatest_pip_openblas_pandas: - DISTRIB: 'conda-latest' - PYTHON_VERSION: '*' + DISTRIB: 'conda-pip-latest' + # FIXME: pinned until SciPy wheels are available for Python 3.8 + PYTHON_VERSION: '3.8' PYTEST_VERSION: '4.6.2' COVERAGE: 'true' CHECK_PYTEST_SOFT_DEPENDENCY: 'true' @@ -43,22 +86,34 @@ jobs: parameters: name: Linux32 vmImage: ubuntu-16.04 + dependsOn: [linting] matrix: py35_ubuntu_atlas_32bit: DISTRIB: 'ubuntu-32' PYTHON_VERSION: '3.5' JOBLIB_VERSION: '0.11' - SKLEARN_NO_OPENMP: 'True' - template: build_tools/azure/posix.yml parameters: name: macOS vmImage: xcode9-macos10.13 + dependsOn: [linting] matrix: pylatest_conda_mkl: DISTRIB: 'conda' PYTHON_VERSION: '*' - INSTALL_MKL: 'true' + BLAS: 'mkl' + NUMPY_VERSION: '*' + SCIPY_VERSION: '*' + CYTHON_VERSION: '*' + PILLOW_VERSION: '*' + PYTEST_VERSION: '*' + JOBLIB_VERSION: '*' + COVERAGE: 'true' + pylatest_conda_mkl_no_openmp: + DISTRIB: 'conda' + PYTHON_VERSION: '*' + BLAS: 'mkl' NUMPY_VERSION: '*' SCIPY_VERSION: '*' CYTHON_VERSION: '*' @@ -66,11 +121,14 @@ jobs: PYTEST_VERSION: '*' JOBLIB_VERSION: '*' COVERAGE: 'true' + SKLEARN_TEST_NO_OPENMP: 'true' + SKLEARN_SKIP_OPENMP_TEST: 'true' - template: build_tools/azure/windows.yml parameters: name: Windows vmImage: vs2017-win2016 + dependsOn: [linting] matrix: py37_conda_mkl: PYTHON_VERSION: '3.7' diff --git a/benchmarks/bench_glmnet.py b/benchmarks/bench_glmnet.py index b05971ba1ff20..e8841cba46d57 100644 --- a/benchmarks/bench_glmnet.py +++ b/benchmarks/bench_glmnet.py @@ -19,7 +19,7 @@ import numpy as np import gc from time import time -from sklearn.datasets.samples_generator import make_regression +from sklearn.datasets import make_regression alpha = 0.1 # alpha = 0.01 diff --git a/benchmarks/bench_lasso.py 
b/benchmarks/bench_lasso.py index 7ed774ad2e790..33054b505ce12 100644 --- a/benchmarks/bench_lasso.py +++ b/benchmarks/bench_lasso.py @@ -15,7 +15,7 @@ from time import time import numpy as np -from sklearn.datasets.samples_generator import make_regression +from sklearn.datasets import make_regression def compute_bench(alpha, n_samples, n_features, precompute): diff --git a/benchmarks/bench_multilabel_metrics.py b/benchmarks/bench_multilabel_metrics.py index d92dae0e0407c..36fc7cb3c47b8 100755 --- a/benchmarks/bench_multilabel_metrics.py +++ b/benchmarks/bench_multilabel_metrics.py @@ -16,7 +16,7 @@ from sklearn.datasets import make_multilabel_classification from sklearn.metrics import (f1_score, accuracy_score, hamming_loss, jaccard_similarity_score) -from sklearn.utils.testing import ignore_warnings +from sklearn.utils._testing import ignore_warnings METRICS = { diff --git a/benchmarks/bench_plot_fastkmeans.py b/benchmarks/bench_plot_fastkmeans.py index a0dc7f5086067..7409232c1edab 100644 --- a/benchmarks/bench_plot_fastkmeans.py +++ b/benchmarks/bench_plot_fastkmeans.py @@ -4,7 +4,7 @@ import numpy as np from numpy import random as nr -from sklearn.cluster.k_means_ import KMeans, MiniBatchKMeans +from sklearn.cluster import KMeans, MiniBatchKMeans def compute_bench(samples_range, features_range): diff --git a/benchmarks/bench_plot_hierarchical.py b/benchmarks/bench_plot_hierarchical.py new file mode 100644 index 0000000000000..3c8cd4464a771 --- /dev/null +++ b/benchmarks/bench_plot_hierarchical.py @@ -0,0 +1,85 @@ +from collections import defaultdict +from time import time + +import numpy as np +from numpy import random as nr + +from sklearn.cluster import AgglomerativeClustering + + +def compute_bench(samples_range, features_range): + + it = 0 + results = defaultdict(lambda: []) + + max_it = len(samples_range) * len(features_range) + for n_samples in samples_range: + for n_features in features_range: + it += 1 + print('==============================') + 
print('Iteration %03d of %03d' % (it, max_it)) + print('n_samples %05d; n_features %02d' % (n_samples, n_features)) + print('==============================') + print() + data = nr.randint(-50, 51, (n_samples, n_features)) + + for linkage in ("single", "average", "complete", "ward"): + print(linkage.capitalize()) + tstart = time() + AgglomerativeClustering( + linkage=linkage, + n_clusters=10 + ).fit(data) + + delta = time() - tstart + print("Speed: %0.3fs" % delta) + print() + + results[linkage].append(delta) + + return results + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + + samples_range = np.linspace(1000, 15000, 8).astype(np.int) + features_range = np.array([2, 10, 20, 50]) + + results = compute_bench(samples_range, features_range) + + max_time = max([max(i) for i in [t for (label, t) in results.items()]]) + + colors = plt.get_cmap('tab10')(np.linspace(0, 1, 10))[:4] + lines = {linkage: None for linkage in results.keys()} + fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) + fig.suptitle( + 'Scikit-learn agglomerative clustering benchmark results', + fontsize=16 + ) + for c, (label, timings) in zip(colors, + sorted(results.items())): + timing_by_samples = np.asarray(timings).reshape( + samples_range.shape[0], + features_range.shape[0] + ) + + for n in range(timing_by_samples.shape[1]): + ax = axs.flatten()[n] + lines[label], = ax.plot( + samples_range, + timing_by_samples[:, n], + color=c, + label=label + ) + ax.set_title('n_features = %d' % features_range[n]) + if n >= 2: + ax.set_xlabel('n_samples') + if n % 2 == 0: + ax.set_ylabel('time (s)') + + fig.subplots_adjust(right=0.8) + fig.legend([lines[link] for link in sorted(results.keys())], + sorted(results.keys()), loc="center right", fontsize=8) + + plt.show() diff --git a/benchmarks/bench_plot_lasso_path.py b/benchmarks/bench_plot_lasso_path.py index ee9ce5bd98a64..8087928b1811d 100644 --- a/benchmarks/bench_plot_lasso_path.py +++ b/benchmarks/bench_plot_lasso_path.py @@ -11,7 
+11,7 @@ from sklearn.linear_model import lars_path, lars_path_gram from sklearn.linear_model import lasso_path -from sklearn.datasets.samples_generator import make_regression +from sklearn.datasets import make_regression def compute_bench(samples_range, features_range): diff --git a/benchmarks/bench_plot_nmf.py b/benchmarks/bench_plot_nmf.py index d8d34d8f952ce..48f1dd1891392 100644 --- a/benchmarks/bench_plot_nmf.py +++ b/benchmarks/bench_plot_nmf.py @@ -16,12 +16,12 @@ from joblib import Memory import pandas -from sklearn.utils.testing import ignore_warnings +from sklearn.utils._testing import ignore_warnings from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.decomposition.nmf import NMF -from sklearn.decomposition.nmf import _initialize_nmf -from sklearn.decomposition.nmf import _beta_divergence -from sklearn.decomposition.nmf import _check_init +from sklearn.decomposition import NMF +from sklearn.decomposition._nmf import _initialize_nmf +from sklearn.decomposition._nmf import _beta_divergence +from sklearn.decomposition._nmf import _check_init from sklearn.exceptions import ConvergenceWarning from sklearn.utils.extmath import safe_sparse_dot, squared_norm from sklearn.utils import check_array diff --git a/benchmarks/bench_plot_omp_lars.py b/benchmarks/bench_plot_omp_lars.py index d762acd619c1d..48a73a60d2fdb 100644 --- a/benchmarks/bench_plot_omp_lars.py +++ b/benchmarks/bench_plot_omp_lars.py @@ -10,7 +10,7 @@ import numpy as np from sklearn.linear_model import lars_path, lars_path_gram, orthogonal_mp -from sklearn.datasets.samples_generator import make_sparse_coded_signal +from sklearn.datasets import make_sparse_coded_signal def compute_bench(samples_range, features_range): diff --git a/benchmarks/bench_plot_randomized_svd.py b/benchmarks/bench_plot_randomized_svd.py index e2c61223a5a5c..e322cda8e87e9 100644 --- a/benchmarks/bench_plot_randomized_svd.py +++ b/benchmarks/bench_plot_randomized_svd.py @@ -77,8 +77,7 @@ from 
sklearn.utils import gen_batches from sklearn.utils.validation import check_random_state from sklearn.utils.extmath import randomized_svd -from sklearn.datasets.samples_generator import (make_low_rank_matrix, - make_sparse_uncorrelated) +from sklearn.datasets import make_low_rank_matrix, make_sparse_uncorrelated from sklearn.datasets import (fetch_lfw_people, fetch_openml, fetch_20newsgroups_vectorized, @@ -105,7 +104,7 @@ # in case the reconstructed (dense) matrix is too large MAX_MEMORY = np.int(2e9) -# The following datasets can be dowloaded manually from: +# The following datasets can be downloaded manually from: # CIFAR 10: https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz # SVHN: http://ufldl.stanford.edu/housenumbers/train_32x32.mat CIFAR_FOLDER = "./cifar-10-batches-py/" diff --git a/benchmarks/bench_plot_svd.py b/benchmarks/bench_plot_svd.py index 746c0df989e90..406fd9ec21f01 100644 --- a/benchmarks/bench_plot_svd.py +++ b/benchmarks/bench_plot_svd.py @@ -9,7 +9,7 @@ from scipy.linalg import svd from sklearn.utils.extmath import randomized_svd -from sklearn.datasets.samples_generator import make_low_rank_matrix +from sklearn.datasets import make_low_rank_matrix def compute_bench(samples_range, features_range, n_iter=3, rank=50): diff --git a/benchmarks/bench_rcv1_logreg_convergence.py b/benchmarks/bench_rcv1_logreg_convergence.py index 52a2cb1a4f33c..051496c4483a2 100644 --- a/benchmarks/bench_rcv1_logreg_convergence.py +++ b/benchmarks/bench_rcv1_logreg_convergence.py @@ -11,7 +11,7 @@ from sklearn.linear_model import (LogisticRegression, SGDClassifier) from sklearn.datasets import fetch_rcv1 -from sklearn.linear_model.sag import get_auto_step_size +from sklearn.linear_model._sag import get_auto_step_size try: import lightning.classification as lightning_clf diff --git a/benchmarks/bench_sgd_regression.py b/benchmarks/bench_sgd_regression.py index d0b9f43f7f590..4c5123c9b6e61 100644 --- a/benchmarks/bench_sgd_regression.py +++ 
b/benchmarks/bench_sgd_regression.py @@ -10,7 +10,7 @@ from sklearn.linear_model import Ridge, SGDRegressor, ElasticNet from sklearn.metrics import mean_squared_error -from sklearn.datasets.samples_generator import make_regression +from sklearn.datasets import make_regression """ Benchmark for SGD regression diff --git a/benchmarks/bench_sparsify.py b/benchmarks/bench_sparsify.py index dd2d6c0f59751..be1f3bffe0181 100644 --- a/benchmarks/bench_sparsify.py +++ b/benchmarks/bench_sparsify.py @@ -45,7 +45,7 @@ from scipy.sparse.csr import csr_matrix import numpy as np -from sklearn.linear_model.stochastic_gradient import SGDRegressor +from sklearn.linear_model import SGDRegressor from sklearn.metrics import r2_score np.random.seed(42) diff --git a/benchmarks/bench_text_vectorizers.py b/benchmarks/bench_text_vectorizers.py index 196e677e9b49c..96dbc04312291 100644 --- a/benchmarks/bench_text_vectorizers.py +++ b/benchmarks/bench_text_vectorizers.py @@ -32,7 +32,7 @@ def f(): text = fetch_20newsgroups(subset='train').data[:1000] print("="*80 + '\n#' + " Text vectorizers benchmark" + '\n' + '='*80 + '\n') -print("Using a subset of the 20 newsrgoups dataset ({} documents)." +print("Using a subset of the 20 newsgroups dataset ({} documents)." 
.format(len(text))) print("This benchmarks runs in ~1 min ...") diff --git a/benchmarks/bench_tsne_mnist.py b/benchmarks/bench_tsne_mnist.py index d36c7af2bff52..8f58a3a41a7e3 100644 --- a/benchmarks/bench_tsne_mnist.py +++ b/benchmarks/bench_tsne_mnist.py @@ -21,7 +21,7 @@ from sklearn.decomposition import PCA from sklearn.utils import check_array from sklearn.utils import shuffle as _shuffle - +from sklearn.utils._openmp_helpers import _openmp_effective_n_threads LOG_DIR = "mnist_tsne_output" if not os.path.exists(LOG_DIR): @@ -86,6 +86,7 @@ def sanitize(filename): "preprocessing.") args = parser.parse_args() + print("Used number of threads: {}".format(_openmp_effective_n_threads())) X, y = load_data(order=args.order) if args.pca_components > 0: @@ -141,7 +142,7 @@ def bhtsne(X): data_size.append(70000) results = [] - basename, _ = os.path.splitext(__file__) + basename = os.path.basename(os.path.splitext(__file__)[0]) log_filename = os.path.join(LOG_DIR, basename + '.json') for n in data_size: X_train = X[:n] diff --git a/build_tools/azure/install.cmd b/build_tools/azure/install.cmd index 1c7ebae521904..2566ba4f4f3aa 100644 --- a/build_tools/azure/install.cmd +++ b/build_tools/azure/install.cmd @@ -11,9 +11,15 @@ IF "%PYTHON_ARCH%"=="64" ( call deactivate @rem Clean up any left-over from a previous build conda remove --all -q -y -n %VIRTUALENV% - conda create -n %VIRTUALENV% -q -y python=%PYTHON_VERSION% numpy scipy cython matplotlib pytest=%PYTEST_VERSION% wheel pillow joblib + conda create -n %VIRTUALENV% -q -y python=%PYTHON_VERSION% numpy scipy cython matplotlib wheel pillow joblib call activate %VIRTUALENV% + + IF "%PYTEST_VERSION%"=="*" ( + pip install pytest + ) else ( + pip install pytest==%PYTEST_VERSION% + ) pip install pytest-xdist ) else ( pip install numpy scipy cython pytest wheel pillow joblib diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 81726d037cca4..aa2707bb03837 100755 --- a/build_tools/azure/install.sh +++ 
b/build_tools/azure/install.sh @@ -4,20 +4,6 @@ set -e UNAMESTR=`uname` -if [[ "$UNAMESTR" == "Darwin" ]]; then - # install OpenMP not present by default on osx - HOMEBREW_NO_AUTO_UPDATE=1 brew install libomp - - # enable OpenMP support for Apple-clang - export CC=/usr/bin/clang - export CXX=/usr/bin/clang++ - export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" - export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" - export LDFLAGS="$LDFLAGS -L/usr/local/opt/libomp/lib -lomp" - export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib -fi - make_conda() { TO_INSTALL="$@" conda create -n $VIRTUALENV --yes $TO_INSTALL @@ -25,24 +11,19 @@ make_conda() { } version_ge() { - # The two version numbers are seperated with a new line is piped to sort + # The two version numbers are separated with a new line is piped to sort # -rV. The -V activates for version number sorting and -r sorts in - # decending order. If the first argument is the top element of the sort, it + # descending order. If the first argument is the top element of the sort, it # is greater than or equal to the second argument. 
test "$(printf "${1}\n${2}" | sort -rV | head -n 1)" == "$1" } if [[ "$DISTRIB" == "conda" ]]; then - TO_INSTALL="python=$PYTHON_VERSION pip pytest=$PYTEST_VERSION \ - pytest-cov numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ - cython=$CYTHON_VERSION joblib=$JOBLIB_VERSION" - - if [[ "$INSTALL_MKL" == "true" ]]; then - TO_INSTALL="$TO_INSTALL mkl" - else - TO_INSTALL="$TO_INSTALL nomkl" - fi + TO_INSTALL="python=$PYTHON_VERSION pip \ + numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ + cython=$CYTHON_VERSION joblib=$JOBLIB_VERSION\ + blas[build=$BLAS]" if [[ -n "$PANDAS_VERSION" ]]; then TO_INSTALL="$TO_INSTALL pandas=$PANDAS_VERSION" @@ -60,6 +41,14 @@ if [[ "$DISTRIB" == "conda" ]]; then TO_INSTALL="$TO_INSTALL matplotlib=$MATPLOTLIB_VERSION" fi + if [[ "$UNAMESTR" == "Darwin" ]]; then + if [[ "$SKLEARN_TEST_NO_OPENMP" != "true" ]]; then + # on macOS, install an OpenMP-enabled clang/llvm from conda-forge. + TO_INSTALL="$TO_INSTALL conda-forge::compilers \ + conda-forge::llvm-openmp" + fi + fi + # Old packages coming from the 'free' conda channel have been removed but # we are using them for testing Python 3.5. 
See # https://www.anaconda.com/why-we-removed-the-free-channel-in-conda-4-7/ @@ -70,12 +59,20 @@ if [[ "$DISTRIB" == "conda" ]]; then fi make_conda $TO_INSTALL + + if [[ "$PYTEST_VERSION" == "*" ]]; then + python -m pip install pytest + else + python -m pip install pytest=="$PYTEST_VERSION" + fi + if [[ "$PYTHON_VERSION" == "*" ]]; then - pip install pytest-xdist + python -m pip install pytest-xdist fi elif [[ "$DISTRIB" == "ubuntu" ]]; then sudo add-apt-repository --remove ppa:ubuntu-toolchain-r/test + sudo apt-get update sudo apt-get install python3-scipy python3-matplotlib libatlas3-base libatlas-base-dev libatlas-dev python3-virtualenv python3 -m virtualenv --system-site-packages --python=python3 $VIRTUALENV source $VIRTUALENV/bin/activate @@ -86,17 +83,19 @@ elif [[ "$DISTRIB" == "ubuntu-32" ]]; then python3 -m virtualenv --system-site-packages --python=python3 $VIRTUALENV source $VIRTUALENV/bin/activate python -m pip install pytest==$PYTEST_VERSION pytest-cov cython joblib==$JOBLIB_VERSION -elif [[ "$DISTRIB" == "conda-latest" ]]; then - # since conda main channel usually lacks behind on the latest releases, +elif [[ "$DISTRIB" == "conda-pip-latest" ]]; then + # Since conda main channel usually lags behind on the latest releases, # we use pypi to test against the latest releases of the dependencies. + # conda is still used as a convenient way to install Python and pip. 
make_conda "python=$PYTHON_VERSION" - python -m pip install numpy scipy joblib cython + python -m pip install -U pip + python -m pip install numpy scipy cython joblib python -m pip install pytest==$PYTEST_VERSION pytest-cov pytest-xdist - python -m pip install pandas matplotlib pyamg pillow + python -m pip install pandas matplotlib pyamg fi if [[ "$COVERAGE" == "true" ]]; then - python -m pip install coverage codecov + python -m pip install coverage codecov pytest-cov fi if [[ "$TEST_DOCSTRINGS" == "true" ]]; then @@ -117,6 +116,9 @@ try: except ImportError: print('pandas not installed') " -pip list +python -m pip list + +# Use setup.py instead of `pip install -e .` to be able to pass the -j flag +# to speed-up the building multicore CI machines. python setup.py build_ext --inplace -j 3 python setup.py develop diff --git a/build_tools/azure/posix-32.yml b/build_tools/azure/posix-32.yml index 127630b61ca65..68e05e347f307 100644 --- a/build_tools/azure/posix-32.yml +++ b/build_tools/azure/posix-32.yml @@ -2,16 +2,18 @@ parameters: name: '' vmImage: '' matrix: [] + dependsOn: [] jobs: - job: ${{ parameters.name }} + dependsOn: ${{ parameters.dependsOn }} pool: vmImage: ${{ parameters.vmImage }} variables: TEST_DIR: '$(Agent.WorkFolder)/tmp_folder' JUNITXML: 'test-data.xml' OMP_NUM_THREADS: '4' - PYTEST_VERSION: '3.8.1' + PYTEST_VERSION: '5.2.1' OPENBLAS_NUM_THREADS: '4' SKLEARN_SKIP_NETWORK_TESTS: '1' strategy: @@ -35,7 +37,6 @@ jobs: -e VIRTUALENV=testvenv -e JOBLIB_VERSION=$JOBLIB_VERSION -e PYTEST_VERSION=$PYTEST_VERSION - -e SKLEARN_NO_OPENMP=$SKLEARN_NO_OPENMP -e OMP_NUM_THREADS=$OMP_NUM_THREADS -e OPENBLAS_NUM_THREADS=$OPENBLAS_NUM_THREADS -e SKLEARN_SKIP_NETWORK_TESTS=$SKLEARN_SKIP_NETWORK_TESTS diff --git a/build_tools/azure/posix.yml b/build_tools/azure/posix.yml index 13bce4963cae9..f5c4a023b4c39 100644 --- a/build_tools/azure/posix.yml +++ b/build_tools/azure/posix.yml @@ -2,16 +2,18 @@ parameters: name: '' vmImage: '' matrix: [] + dependsOn: [] jobs: - 
job: ${{ parameters.name }} + dependsOn: ${{ parameters.dependsOn }} pool: vmImage: ${{ parameters.vmImage }} variables: TEST_DIR: '$(Agent.WorkFolder)/tmp_folder' VIRTUALENV: 'testvenv' JUNITXML: 'test-data.xml' - PYTEST_VERSION: '3.8.1' + PYTEST_VERSION: '5.2.1' OMP_NUM_THREADS: '4' OPENBLAS_NUM_THREADS: '4' SKLEARN_SKIP_NETWORK_TESTS: '1' diff --git a/build_tools/azure/windows.yml b/build_tools/azure/windows.yml index e5a1eaf5fd9ce..24b542b227dd8 100644 --- a/build_tools/azure/windows.yml +++ b/build_tools/azure/windows.yml @@ -3,16 +3,18 @@ parameters: name: '' vmImage: '' matrix: [] + dependsOn: [] jobs: - job: ${{ parameters.name }} + dependsOn: ${{ parameters.dependsOn }} pool: vmImage: ${{ parameters.vmImage }} variables: VIRTUALENV: 'testvenv' JUNITXML: 'test-data.xml' SKLEARN_SKIP_NETWORK_TESTS: '1' - PYTEST_VERSION: '3.8.1' + PYTEST_VERSION: '5.2.1' TMP_FOLDER: '$(Agent.WorkFolder)\tmp_folder' strategy: matrix: diff --git a/build_tools/circle/build_doc.sh b/build_tools/circle/build_doc.sh index 5f5037319a37d..abc823facee15 100755 --- a/build_tools/circle/build_doc.sh +++ b/build_tools/circle/build_doc.sh @@ -58,6 +58,44 @@ get_build_type() { return fi changed_examples=$(echo "$filenames" | grep -E "^examples/(.*/)*plot_") + + # The following is used to extract the list of filenames of example python + # files that sphinx-gallery needs to run to generate png files used as + # figures or images in the .rst files from the documentation. + # If the contributor changes a .rst file in a PR we need to run all + # the examples mentioned in that file to get sphinx build the + # documentation without generating spurious warnings related to missing + # png files. 
+ + if [[ -n "$filenames" ]] + then + # get rst files + rst_files="$(echo "$filenames" | grep -E "rst$")" + + # get lines with figure or images + img_fig_lines="$(echo "$rst_files" | xargs grep -shE "(figure|image)::")" + + # get only auto_examples + auto_example_files="$(echo "$img_fig_lines" | grep auto_examples | awk -F "/" '{print $NF}')" + + # remove "sphx_glr_" from path and accept replace _(\d\d\d|thumb).png with .py + scripts_names="$(echo "$auto_example_files" | sed 's/sphx_glr_//' | sed -E 's/_([[:digit:]][[:digit:]][[:digit:]]|thumb).png/.py/')" + + # get unique values + examples_in_rst="$(echo "$scripts_names" | uniq )" + fi + + # executed only if there are examples in the modified rst files + if [[ -n "$examples_in_rst" ]] + then + if [[ -n "$changed_examples" ]] + then + changed_examples="$changed_examples|$examples_in_rst" + else + changed_examples="$examples_in_rst" + fi + fi + if [[ -n "$changed_examples" ]] then echo BUILD: detected examples/ filename modified in $git_range: $changed_examples @@ -125,16 +163,17 @@ if [[ "$CIRCLE_JOB" == "doc-min-dependencies" ]]; then conda config --set restore_free_channel true fi +# packaging won't be needed once setuptools starts shipping packaging>=17.0 conda create -n $CONDA_ENV_NAME --yes --quiet python="${PYTHON_VERSION:-*}" \ numpy="${NUMPY_VERSION:-*}" scipy="${SCIPY_VERSION:-*}" \ cython="${CYTHON_VERSION:-*}" pytest coverage \ matplotlib="${MATPLOTLIB_VERSION:-*}" sphinx=2.1.2 pillow \ scikit-image="${SCIKIT_IMAGE_VERSION:-*}" pandas="${PANDAS_VERSION:-*}" \ - joblib memory_profiler + joblib memory_profiler packaging source activate testenv -pip install sphinx-gallery==0.3.1 -pip install numpydoc==0.9 +pip install sphinx-gallery +pip install numpydoc # Build and install scikit-learn in dev mode python setup.py build_ext --inplace -j 3 @@ -169,14 +208,46 @@ affected_doc_paths() { fi } +affected_doc_warnings() { + files=$(git diff --name-only origin/master...$CIRCLE_SHA1) + # Look for sphinx warnings only 
in files affected by the PR + if [ -n "$files" ] + then + for af in ${files[@]} + do + warn+=`grep WARNING ~/log.txt | grep $af` + done + fi + echo "$warn" +} + if [ -n "$CI_PULL_REQUEST" ] then + echo "The following documentation warnings may have been generated by PR #$CI_PULL_REQUEST:" + warnings=$(affected_doc_warnings) + if [ -z "$warnings" ] + then + warnings="/home/circleci/project/ no warnings" + fi + echo "$warnings" + echo "The following documentation files may have been changed by PR #$CI_PULL_REQUEST:" affected=$(affected_doc_paths) echo "$affected" ( echo '
    ' echo "$affected" | sed 's|.*|
  • & [dev, stable]
  • |' - echo '

General: Home | API Reference | Examples

' + echo '

General: Home | API Reference | Examples

' + echo 'Sphinx Warnings in affected files
    ' + echo "$warnings" | sed 's/\/home\/circleci\/project\//
  • /g' + echo '
' ) > 'doc/_build/html/stable/_changed.html' + + if [ "$warnings" != "/home/circleci/project/ no warnings" ] + then + echo "Sphinx generated warnings when building the documentation related to files modified in this PR." + echo "Please check doc/_build/html/stable/_changed.html" + exit 1 + fi fi + diff --git a/build_tools/circle/build_test_pypy.sh b/build_tools/circle/build_test_pypy.sh index 60b81e60709f0..22e4790e7e4ab 100755 --- a/build_tools/circle/build_test_pypy.sh +++ b/build_tools/circle/build_test_pypy.sh @@ -18,11 +18,14 @@ source pypy-env/bin/activate python --version which python -# XXX: numpy version pinning can be reverted once PyPy -# compatibility is resolved for numpy v1.6.x. For instance, -# when PyPy3 >6.0 is released (see numpy/numpy#12740) -pip install --extra-index https://antocuni.github.io/pypy-wheels/ubuntu numpy Cython pytest -pip install scipy sphinx numpydoc docutils joblib pillow +pip install -U pip + +# pins versions to install wheel from https://antocuni.github.io/pypy-wheels/manylinux2010 +pip install --extra-index-url https://antocuni.github.io/pypy-wheels/manylinux2010 numpy==1.18.0 scipy==1.3.2 + +# Install Cython directly +pip install https://antocuni.github.io/pypy-wheels/ubuntu/Cython/Cython-0.29.14-py3-none-any.whl +pip install sphinx numpydoc docutils joblib pillow pytest ccache -M 512M export CCACHE_COMPRESS=1 diff --git a/build_tools/circle/check_deprecated_properties.sh b/build_tools/circle/check_deprecated_properties.sh deleted file mode 100755 index 8cbb97c774e21..0000000000000 --- a/build_tools/circle/check_deprecated_properties.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -# For docstrings and warnings of deprecated attributes to be rendered -# properly, the property decorator must come before the deprecated decorator -# (else they are treated as functions) -bad_deprecation_property_order=`git grep -A 10 "@property" | awk '/@property/,/def /' | grep -B1 "@deprecated"` -# exclude this file from the matches 
-bad_deprecation_property_order=`echo $bad_deprecation_property_order | grep -v check_deprecated_properties` - -if [ ! -z "$bad_deprecation_property_order" ] -then - echo "property decorator should come before deprecated decorator" - echo "found the following occurrencies:" - echo $bad_deprecation_property_order - exit 1 -fi diff --git a/build_tools/circle/flake8_diff.sh b/build_tools/circle/linting.sh similarity index 90% rename from build_tools/circle/flake8_diff.sh rename to build_tools/circle/linting.sh index 7a7fe7f12f241..2b408031c2eb6 100755 --- a/build_tools/circle/flake8_diff.sh +++ b/build_tools/circle/linting.sh @@ -143,3 +143,19 @@ else --config ./examples/.flake8 fi echo -e "No problem detected by flake8\n" + +# For docstrings and warnings of deprecated attributes to be rendered +# properly, the property decorator must come before the deprecated decorator +# (else they are treated as functions) + +# do not error when grep -B1 "@property" finds nothing +set +e +bad_deprecation_property_order=`git grep -A 10 "@property" -- "*.py" | awk '/@property/,/def /' | grep -B1 "@deprecated"` + +if [ ! 
-z "$bad_deprecation_property_order" ] +then + echo "property decorator should come before deprecated decorator" + echo "found the following occurrencies:" + echo $bad_deprecation_property_order + exit 1 +fi diff --git a/build_tools/circle/list_versions.py b/build_tools/circle/list_versions.py index c7b96abee852b..19fa8aa2dc991 100755 --- a/build_tools/circle/list_versions.py +++ b/build_tools/circle/list_versions.py @@ -49,7 +49,7 @@ def get_pdf_size(version): print() ROOT_URL = 'https://api.github.com/repos/scikit-learn/scikit-learn.github.io/contents/' # noqa -RAW_FMT = 'https://raw.githubusercontent.com/scikit-learn/scikit-learn.github.io/master/%s/documentation.html' # noqa +RAW_FMT = 'https://raw.githubusercontent.com/scikit-learn/scikit-learn.github.io/master/%s/index.html' # noqa VERSION_RE = re.compile(r"scikit-learn ([\w\.\-]+) documentation") NAMED_DIRS = ['dev', 'stable'] @@ -88,8 +88,8 @@ def get_pdf_size(version): else: seen.add(version_num) name_display = '' if name[:1].isdigit() else ' (%s)' % name - path = 'http://scikit-learn.org/%s' % name - out = ('* `Scikit-learn %s%s documentation <%s/documentation.html>`_' + path = 'https://scikit-learn.org/%s/' % name + out = ('* `Scikit-learn %s%s documentation <%s>`_' % (version_num, name_display, path)) if pdf_size is not None: out += (' (`PDF %s <%s/_downloads/scikit-learn-docs.pdf>`_)' diff --git a/build_tools/generate_authors_table.py b/build_tools/generate_authors_table.py index 3627875cc5656..81e99856c6890 100644 --- a/build_tools/generate_authors_table.py +++ b/build_tools/generate_authors_table.py @@ -10,14 +10,15 @@ import requests import getpass import time +from pathlib import Path print("user:", file=sys.stderr) user = input() passwd = getpass.getpass("Password or access token:\n") auth = (user, passwd) -ROW_SIZE = 7 LOGO_URL = 'https://avatars2.githubusercontent.com/u/365630?v=4' +REPO_FOLDER = Path(__file__).parent.parent def get(url): @@ -34,18 +35,6 @@ def get(url): return reply -def 
group_iterable(iterable, size): - """Group iterable into lines""" - group = [] - for element in iterable: - group.append(element) - if len(group) == size: - yield group - group = [] - if len(group) != 0: - yield group - - def get_contributors(): """Get the list of contributor profiles. Require admin rights.""" # get members of scikit-learn core-dev on GitHub @@ -120,33 +109,28 @@ def get_profile(login): def key(profile): - """Get the last name in lower case""" - return profile["name"].split(' ')[-1].lower() + """Get a sorting key based on the lower case last name, then firstname""" + components = profile["name"].lower().split(' ') + return " ".join([components[-1]] + components[:-1]) def generate_table(contributors): lines = [ (".. raw :: html\n"), (" "), - (" "), - (" " % - (int(100 / ROW_SIZE), ROW_SIZE)), + ("
"), (" "), ] - for row in group_iterable(contributors, size=ROW_SIZE): - lines.append("
") - for contributor in row: - lines.append(" ") - lines.append(" ") - lines.append("
") - lines.append( - "
" % - (contributor["html_url"], contributor["avatar_url"])) - lines.append("

%s

" % (contributor["name"], )) - lines.append("
") + for contributor in contributors: + lines.append("
") + lines.append( + "
" % + (contributor["html_url"], contributor["avatar_url"])) + lines.append("

%s

" % (contributor["name"], )) + lines.append("
") + lines.append(" ") return '\n'.join(lines) @@ -161,8 +145,8 @@ def generate_list(contributors): core_devs, emeritus = get_contributors() - with open("../doc/authors.rst", "w+") as rst_file: + with open(REPO_FOLDER / "doc" / "authors.rst", "w+") as rst_file: rst_file.write(generate_table(core_devs)) - with open("../doc/authors_emeritus.rst", "w+") as rst_file: + with open(REPO_FOLDER / "doc" / "authors_emeritus.rst", "w+") as rst_file: rst_file.write(generate_list(emeritus)) diff --git a/build_tools/travis/install.sh b/build_tools/travis/install.sh index a0481025931ba..6bb15b3f539e1 100755 --- a/build_tools/travis/install.sh +++ b/build_tools/travis/install.sh @@ -16,133 +16,53 @@ set -e # Fail fast build_tools/travis/travis_fastfail.sh -echo 'List files from cached directories' -echo 'pip:' +echo "List files from cached directories" +echo "pip:" ls $HOME/.cache/pip -if [ $TRAVIS_OS_NAME = "linux" ] -then - export CC=/usr/lib/ccache/gcc - export CXX=/usr/lib/ccache/g++ - # Useful for debugging how ccache is used - # export CCACHE_LOGFILE=/tmp/ccache.log - # ~60M is used by .ccache when compiling from scratch at the time of writing - ccache --max-size 100M --show-stats -elif [ $TRAVIS_OS_NAME = "osx" ] -then - # enable OpenMP support for Apple-clang - export CC=/usr/bin/clang - export CXX=/usr/bin/clang++ - export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" - export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" - export LDFLAGS="$LDFLAGS -L/usr/local/opt/libomp/lib -lomp" - export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib -fi - -make_conda() { - TO_INSTALL="$@" - # Deactivate the travis-provided virtual environment and setup a - # conda-based environment instead - # If Travvis has language=generic, deactivate does not exist. `|| :` will pass. 
- deactivate || : - - # Install miniconda - if [ $TRAVIS_OS_NAME = "osx" ] - then - fname=Miniconda3-latest-MacOSX-x86_64.sh - else - fname=Miniconda3-latest-Linux-x86_64.sh - fi - wget https://repo.continuum.io/miniconda/$fname \ - -O miniconda.sh - MINICONDA_PATH=$HOME/miniconda - chmod +x miniconda.sh && ./miniconda.sh -b -p $MINICONDA_PATH - export PATH=$MINICONDA_PATH/bin:$PATH - conda update --yes conda - - conda create -n testenv --yes $TO_INSTALL - source activate testenv -} - -if [[ "$DISTRIB" == "conda" ]]; then - TO_INSTALL="python=$PYTHON_VERSION pip pytest pytest-cov \ - numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \ - cython=$CYTHON_VERSION" - - if [[ "$INSTALL_MKL" == "true" ]]; then - TO_INSTALL="$TO_INSTALL mkl" - else - TO_INSTALL="$TO_INSTALL nomkl" - fi - - if [[ -n "$PANDAS_VERSION" ]]; then - TO_INSTALL="$TO_INSTALL pandas=$PANDAS_VERSION" - fi - - if [[ -n "$PYAMG_VERSION" ]]; then - TO_INSTALL="$TO_INSTALL pyamg=$PYAMG_VERSION" - fi - - if [[ -n "$PILLOW_VERSION" ]]; then - TO_INSTALL="$TO_INSTALL pillow=$PILLOW_VERSION" - fi - - if [[ -n "$JOBLIB_VERSION" ]]; then - TO_INSTALL="$TO_INSTALL joblib=$JOBLIB_VERSION" - fi - make_conda $TO_INSTALL - -elif [[ "$DISTRIB" == "ubuntu" ]]; then - # At the time of writing numpy 1.9.1 is included in the travis - # virtualenv but we want to use the numpy installed through apt-get - # install. 
- deactivate - # Create a new virtualenv using system site packages for python, numpy - # and scipy - virtualenv --system-site-packages --python=python3 testvenv - source testvenv/bin/activate - pip install pytest pytest-cov cython joblib==$JOBLIB_VERSION - -elif [[ "$DISTRIB" == "scipy-dev" ]]; then - make_conda python=3.7 - pip install --upgrade pip setuptools - - echo "Installing numpy and scipy master wheels" - dev_url=https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com - pip install --pre --upgrade --timeout=60 -f $dev_url numpy scipy pandas cython - echo "Installing joblib master" - pip install https://github.com/joblib/joblib/archive/master.zip - echo "Installing pillow master" - pip install https://github.com/python-pillow/Pillow/archive/master.zip - pip install pytest==4.6.4 pytest-cov -fi - -if [[ "$COVERAGE" == "true" ]]; then - pip install coverage codecov -fi - -if [[ "$TEST_DOCSTRINGS" == "true" ]]; then - pip install sphinx numpydoc # numpydoc requires sphinx -fi +export CC=/usr/lib/ccache/gcc +export CXX=/usr/lib/ccache/g++ +# Useful for debugging how ccache is used +# export CCACHE_LOGFILE=/tmp/ccache.log +# ~60M is used by .ccache when compiling from scratch at the time of writing +ccache --max-size 100M --show-stats + +# Deactivate the travis-provided virtual environment and set up a +# conda-based environment instead +# If Travis has language=generic, deactivate does not exist. `|| :` will pass. 
+deactivate || : + +# Install miniconda +fname=Miniconda3-latest-Linux-x86_64.sh +wget https://repo.continuum.io/miniconda/$fname -O miniconda.sh +MINICONDA_PATH=$HOME/miniconda +chmod +x miniconda.sh && ./miniconda.sh -b -p $MINICONDA_PATH +export PATH=$MINICONDA_PATH/bin:$PATH +conda update --yes conda + +# Create environment and install dependencies +conda create -n testenv --yes python=3.7 +source activate testenv + +pip install --upgrade pip setuptools +echo "Installing numpy and scipy master wheels" +dev_url=https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com +pip install --pre --upgrade --timeout=60 -f $dev_url numpy scipy pandas cython +echo "Installing joblib master" +pip install https://github.com/joblib/joblib/archive/master.zip +echo "Installing pillow master" +pip install https://github.com/python-pillow/Pillow/archive/master.zip +pip install pytest==4.6.4 pytest-cov # Build scikit-learn in the install.sh script to collapse the verbose # build output in the travis output when it succeeds. 
python --version python -c "import numpy; print('numpy %s' % numpy.__version__)" python -c "import scipy; print('scipy %s' % scipy.__version__)" -python -c "\ -try: - import pandas - print('pandas %s' % pandas.__version__) -except ImportError: - pass -" + python setup.py develop -if [ $TRAVIS_OS_NAME = "linux" ] -then - ccache --show-stats -fi + +ccache --show-stats # Useful for debugging how ccache is used # cat $CCACHE_LOGFILE diff --git a/conftest.py b/conftest.py index 73326d6d2e32b..b98bb4b271aca 100644 --- a/conftest.py +++ b/conftest.py @@ -7,12 +7,15 @@ import platform from distutils.version import LooseVersion +import os import pytest from _pytest.doctest import DoctestItem from sklearn import set_config from sklearn.utils import _IS_32BIT +from sklearn.externals import _pilutil +from sklearn._build_utils.deprecated_modules import _DEPRECATED_MODULES PYTEST_MIN_VERSION = '3.3.0' @@ -34,9 +37,8 @@ def pytest_collection_modifyitems(config, items): skip_marker = pytest.mark.skip( reason='FeatureHasher is not compatible with PyPy') for item in items: - if item.name in ( - 'sklearn.feature_extraction.hashing.FeatureHasher', - 'sklearn.feature_extraction.text.HashingVectorizer'): + if item.name.endswith(('_hash.FeatureHasher', + 'text.HashingVectorizer')): item.add_marker(skip_marker) # Skip tests which require internet if the flag is provided @@ -68,6 +70,13 @@ def pytest_collection_modifyitems(config, items): for item in items: if isinstance(item, DoctestItem): item.add_marker(skip_marker) + elif not _pilutil.pillow_installed: + skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!") + for item in items: + if item.name in [ + "sklearn.feature_extraction.image.PatchExtractor", + "sklearn.feature_extraction.image.extract_patches_2d"]: + item.add_marker(skip_marker) def pytest_configure(config): @@ -88,3 +97,10 @@ def pytest_runtest_setup(item): def pytest_runtest_teardown(item, nextitem): if isinstance(item, DoctestItem): 
set_config(print_changed_only=False) + + +# TODO: Remove when modules are deprecated in 0.24 +# Configures pytest to ignore deprecated modules. +collect_ignore_glob = [ + os.path.join(*deprecated_path.split(".")) + ".py" + for _, deprecated_path, _, _ in _DEPRECATED_MODULES] diff --git a/doc/Makefile b/doc/Makefile index 6629518fc556a..1cbce7dba9662 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -2,7 +2,7 @@ # # You can set these variables from the command line. -SPHINXOPTS = +SPHINXOPTS = -j auto SPHINXBUILD ?= sphinx-build PAPER = BUILDDIR = _build @@ -98,7 +98,7 @@ doctest: "results in $(BUILDDIR)/doctest/output.txt." download-data: - python -c "from sklearn.datasets.lfw import _check_fetch_lfw; _check_fetch_lfw()" + python -c "from sklearn.datasets._lfw import _check_fetch_lfw; _check_fetch_lfw()" # Optimize PNG files. Needs OptiPNG. Change the -P argument to the number of # cores you have available, so -P 64 if you have a real computer ;) diff --git a/doc/about.rst b/doc/about.rst index bb628469d239b..2008d96af0045 100644 --- a/doc/about.rst +++ b/doc/about.rst @@ -96,86 +96,346 @@ following paper: Artwork ------- -High quality PNG and SVG logos are available in the `doc/logos/ `_ source directory. +High quality PNG and SVG logos are available in the `doc/logos/ +`_ +source directory. .. image:: images/scikit-learn-logo-notext.png :align: center Funding ------- +Scikit-Learn is a community driven project, however institutional and private +grants help to assure its sustainability. + +The project would like to thank the following funders. + +................................... + +.. raw:: html + +
+
+ +The `Members `_ of +the `Scikit-Learn Consortium at Inria Foundation +`_ fund Olivier +Grisel, Guillaume Lemaitre, Jérémie du Boisberranger and Chiara Marmo. + +.. raw:: html + +
+ +.. |msn| image:: images/microsoft.png + :width: 100pt + :target: https://www.microsoft.com/ + +.. |bcg| image:: images/bcg.png + :width: 100pt + :target: https://www.bcg.com/beyond-consulting/bcg-gamma/default.aspx + +.. |axa| image:: images/axa.png + :width: 50pt + :target: https://www.axa.fr/ + +.. |bnp| image:: images/bnp.png + :width: 170pt + :target: https://www.bnpparibascardif.com/ + +.. |fujitsu| image:: images/fujitsu.png + :width: 100pt + :target: https://www.fujitsu.com/global/ + +.. |intel| image:: images/intel.png + :width: 70pt + :target: https://www.intel.com/ + +.. |nvidia| image:: images/nvidia.png + :width: 70pt + :target: https://www.nvidia.com/ + +.. |dataiku| image:: images/dataiku.png + :width: 70pt + :target: https://www.dataiku.com/ + +.. |inria| image:: images/inria-logo.jpg + :width: 100pt + :target: https://www.inria.fr + + +.. raw:: html + +
+ +.. table:: + :class: sk-sponsor-table align-default + + +---------+----------+ + | |msn| | |bcg| | + +---------+----------+ + | | + +---------+----------+ + | |axa| ||fujitsu| | + +---------+----------+ + | |bnp| | + +---------+----------+ + | |intel| | |nvidia| | + +---------+----------+ + | | + +---------+----------+ + ||dataiku|| |inria| | + +---------+----------+ + +.. raw:: html + +
+
+ +........ + +.. raw:: html + +
+
+ +`Columbia University `_ funds Andreas Müller since 2016 + +.. raw:: html + +
+ +
+ +.. image:: themes/scikit-learn/static/img/columbia.png + :width: 50pt + :align: center + :target: https://www.columbia.edu/ + +.. raw:: html + +
+
+ +.......... + +.. raw:: html + +
+
+ +Andreas Müller received a grant to improve scikit-learn from the +`Alfred P. Sloan Foundation `_ . +This grant supports the position of Nicolas Hug and Thomas J. Fan. + +.. raw:: html + +
+ +
+ +.. image:: images/sloan_banner.png + :width: 100pt + :align: center + :target: https://sloan.org/ + +.. raw:: html + +
+
+ +........... + +.. raw:: html + +
+
+ +`The University of Sydney `_ funds Joel Nothman since +July 2017. + +.. raw:: html + +
+ +
+ +.. image:: themes/scikit-learn/static/img/sydney-primary.jpeg + :width: 100pt + :align: center + :target: https://sydney.edu.au/ + +.. raw:: html + +
+
+ +............ + +.. raw:: html + +
+
+ +`Anaconda, Inc `_ funds Adrin Jalali since 2019. + +.. raw:: html + +
+ +
+ +.. image:: images/anaconda.png + :width: 100pt + :align: center + :target: https://sydney.edu.au/ + +.. raw:: html + +
+
+ +Past Sponsors +............. + +.. raw:: html + +
+
`INRIA `_ actively supports this project. It has provided funding for Fabian Pedregosa (2010-2012), Jaques Grobler (2012-2013) and Olivier Grisel (2013-2017) to work on this project full-time. It also hosts coding sprints and other events. +.. raw:: html + +
+ +
+ .. image:: images/inria-logo.jpg - :width: 200pt + :width: 100pt :align: center :target: https://www.inria.fr -`Paris-Saclay Center for Data Science `_ +.. raw:: html + +
+
+ +..................... + +.. raw:: html + +
+
+ +`Paris-Saclay Center for Data Science +`_ funded one year for a developer to work on the project full-time -(2014-2015) and 50% of the time of Guillaume Lemaitre (2016-2017). +(2014-2015), 50% of the time of Guillaume Lemaitre (2016-2017) and 50% of the +time of Joris van den Bossche (2017-2018). + +.. raw:: html + +
+
.. image:: images/cds-logo.png - :width: 200pt + :width: 100pt :align: center :target: https://www.datascience-paris-saclay.fr/ +.. raw:: html + +
+
+ +.......................... + +.. raw:: html + +
+
+ `NYU Moore-Sloan Data Science Environment `_ -funded Andreas Mueller (2014-2016) to work on this project. The Moore-Sloan Data Science -Environment also funds several students to work on the project part-time. +funded Andreas Mueller (2014-2016) to work on this project. The Moore-Sloan +Data Science Environment also funds several students to work on the project +part-time. + +.. raw:: html + +
+
.. image:: images/nyu_short_color.png - :width: 200pt + :width: 100pt :align: center :target: https://cds.nyu.edu/mooresloan/ +.. raw:: html + +
+
-`Télécom Paristech `_ funded Manoj Kumar (2014), -Tom Dupré la Tour (2015), Raghav RV (2015-2017), Thierry Guillemot (2016-2017) -and Albert Thomas (2017) to work on scikit-learn. +........................ + +.. raw:: html + +
+
+ +`Télécom Paristech `_ funded Manoj Kumar +(2014), Tom Dupré la Tour (2015), Raghav RV (2015-2017), Thierry Guillemot +(2016-2017) and Albert Thomas (2017) to work on scikit-learn. + +.. raw:: html + +
+
.. image:: themes/scikit-learn/static/img/telecom.png - :width: 100pt + :width: 50pt :align: center :target: https://www.telecom-paristech.fr/ +.. raw:: html -`Columbia University `_ funds Andreas Müller since 2016. +
+
-.. image:: themes/scikit-learn/static/img/columbia.png - :width: 100pt - :align: center - :target: https://www.columbia.edu/ +..................... -Andreas Müller also received a grant to improve scikit-learn from the `Alfred P. Sloan Foundation `_ in 2017. +.. raw:: html -.. image:: images/sloan_banner.png - :width: 200pt - :align: center - :target: https://sloan.org/ +
+
-`The University of Sydney `_ funds Joel Nothman since July 2017. +`The Labex DigiCosme `_ funded Nicolas Goix +(2015-2016), Tom Dupré la Tour (2015-2016 and 2017-2018), Mathurin Massias +(2018-2019) to work part time on scikit-learn during their PhDs. It also +funded a scikit-learn coding sprint in 2015. -.. image:: themes/scikit-learn/static/img/sydney-primary.jpeg - :width: 200pt - :align: center - :target: https://sydney.edu.au/ +.. raw:: html -`The Labex DigiCosme `_ funded Nicolas Goix (2015-2016), -Tom Dupré la Tour (2015-2016 and 2017-2018), Mathurin Massias (2018-2019) to work part time -on scikit-learn during their PhDs. It also funded a scikit-learn coding sprint in 2015. +
+
.. image:: themes/scikit-learn/static/img/digicosme.png - :width: 200pt + :width: 100pt :align: center :target: https://digicosme.lri.fr -The following students were sponsored by `Google `_ -to work on scikit-learn through the -`Google Summer of Code `_ +.. raw:: html + +
+
+ +...................... + +The following students were sponsored by `Google +`_ to work on scikit-learn through +the `Google Summer of Code `_ program. - 2007 - David Cournapeau @@ -188,29 +448,43 @@ program. .. _Vlad Niculae: https://vene.ro/ -It also provided funding for sprints and events around scikit-learn. If -you would like to participate in the next Google Summer of code -program, please see `this page -`_. +................... The `NeuroDebian `_ project providing `Debian `_ packaging and contributions is supported by `Dr. James V. Haxby `_ (`Dartmouth College `_). -The `PSF `_ helped find and manage funding for our -2011 Granada sprint. More information can be found `here -`__ +Sprints +------- + +The International 2019 Paris sprint was kindly hosted by `AXA `_. +Also some participants could attend thanks to the support of the `Alfred P. +Sloan Foundation `_, the `Python Software +Foundation `_ (PSF) and the `DATAIA Institute +`_. + +..................... -`tinyclues `_ funded the 2011 international Granada -sprint. +The 2013 International Paris Sprint was made possible thanks to the support of +`Télécom Paristech `_, `tinyclues +`_, the `French Python Association +`_ and the `Fonds de la Recherche Scientifique +`_. +.............. + +The 2011 International Granada sprint was made possible thanks to the support +of the `PSF `_ and `tinyclues +`_. Donating to the project -~~~~~~~~~~~~~~~~~~~~~~~ +....................... -If you are interested in donating to the project or to one of our code-sprints, you can use -the *Paypal* button below or the `NumFOCUS Donations Page `_ (if you use the latter, please indicate that you are donating for the scikit-learn project). +If you are interested in donating to the project or to one of our code-sprints, +you can use the *Paypal* button below or the `NumFOCUS Donations Page +`_ (if you use the latter, +please indicate that you are donating for the scikit-learn project). 
All donations will be handled by `NumFOCUS `_, a non-profit-organization which is @@ -220,8 +494,9 @@ scientific computing software, in particular in Python. As a fiscal home of scikit-learn, it ensures that money is available when needed to keep the project funded and available while in compliance with tax regulations. -The received donations for the scikit-learn project mostly will go towards covering travel-expenses -for code sprints, as well as towards the organization budget of the project [#f1]_. +The received donations for the scikit-learn project mostly will go towards +covering travel-expenses for code sprints, as well as towards the organization +budget of the project [#f1]_. .. raw :: html @@ -243,69 +518,20 @@ for code sprints, as well as towards the organization budget of the project [#f1 .. rubric:: Notes -.. [#f1] Regarding the organization budget in particular, we might use some of the donated funds to pay for other project expenses such as DNS, hosting or continuous integration services. - - -The 2013 Paris international sprint -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -|center-div| |telecom| |tinyclues| |afpy| |FNRS| - - |end-div| - - - -.. |center-div| raw:: html - -
- - -.. |telecom| image:: themes/scikit-learn/static/img/telecom.png - :width: 120pt - :target: https://www.telecom-paristech.fr/ - - -.. |tinyclues| image:: https://www.tinyclues.com/web/wp-content/uploads/2016/06/Tinyclues-PNG-logo.png - :width: 120pt - :target: https://www.tinyclues.com/ - - -.. |afpy| image:: https://www.afpy.org/static/images/logo.svg - :width: 120pt - :target: https://www.afpy.org - - -.. |SGR| image:: http://www.svi.cnrs-bellevue.fr/wikimedia/images/Logo_svi_inp.png - :width: 120pt - :target: http://www.svi.cnrs-bellevue.fr - -.. |FNRS| image:: http://www.fnrs.be/en/images/FRS-FNRS_rose_transp.png - :width: 120pt - :target: http://www.frs-fnrs.be/ - -.. figure:: images/dysco.png - :width: 120pt - :target: https://sites.uclouvain.be/dysco/ - - IAP VII/19 - DYSCO - -.. |end-div| raw:: html - -
- -*For more information on this sprint, see* `here -`__ - +.. [#f1] Regarding the organization budget in particular, we might use some of + the donated funds to pay for other project expenses such as DNS, + hosting or continuous integration services. Infrastructure support ---------------------- - We would like to thank `Rackspace `_ for providing - us with a free `Rackspace Cloud `_ account to - automatically build the documentation and the example gallery from for the + us with a free `Rackspace Cloud `_ account + to automatically build the documentation and the example gallery for the development version of scikit-learn using `this tool `_. -- We would also like to thank `Shining Panda - `_ for free CPU time on their Continuous - Integration server. +- We would also like to thank `Microsoft Azure + `_, `Travis CI `_, + `CircleCI `_ for free CPU time on their Continuous + Integration servers. diff --git a/doc/authors.rst b/doc/authors.rst index 1a0e9363ec97c..6a03871d67e90 100644 --- a/doc/authors.rst +++ b/doc/authors.rst @@ -1,96 +1,88 @@ .. raw :: html - - +
-
- - - - - - - - - - - - - - - - - - - - - - - - - -
+

Jérémie Du Boisberranger

-
+ +

Joris Van den Bossche

-
+ +

Loïc Estève

-
+ +

Thomas J Fan

-
+ +

Alexandre Gramfort

-
+ +

Olivier Grisel

-
+ +

Yaroslav Halchenko

-
+ +

Nicolas Hug

-
+ +

Adrin Jalali

-
+ +

Guillaume Lemaitre

-
+ +

Jan Hendrik Metzen

-
+ +

Andreas Mueller

-
+ +

Vlad Niculae

-
+ +

Joel Nothman

-
+ +

Hanmin Qin

-
+ +

Bertrand Thirion

-
+ +

Tom Dupré la Tour

-
-
-

Nelle Varoquaux

-
+ +

Gael Varoquaux

-
+ +
+
+

Nelle Varoquaux

+
+

Roman Yurchak

-
\ No newline at end of file + + \ No newline at end of file diff --git a/doc/authors_emeritus.rst b/doc/authors_emeritus.rst index 5eb0ccf0a8cef..bcfd7d7d0514c 100644 --- a/doc/authors_emeritus.rst +++ b/doc/authors_emeritus.rst @@ -16,7 +16,7 @@ - Arnaud Joly - Thouis (Ray) Jones - Kyle Kastner -- Manoj Kumar +- manoj kumar - Robert Layton - Wei Li - Paolo Losi diff --git a/doc/conf.py b/doc/conf.py index ef89cb7fb0a35..c4d7e578216fd 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -16,6 +16,7 @@ import os import warnings import re +from packaging.version import parse # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory @@ -50,11 +51,11 @@ if os.environ.get('NO_MATHJAX'): extensions.append('sphinx.ext.imgmath') imgmath_image_format = 'svg' + mathjax_path = '' else: extensions.append('sphinx.ext.mathjax') - mathjax_path = ('https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/' - 'MathJax.js?config=TeX-AMS_SVG') - + mathjax_path = ('https://cdn.jsdelivr.net/npm/mathjax@3/es5/' + 'tex-chtml.js') autodoc_default_options = { 'members': True, @@ -74,7 +75,7 @@ #source_encoding = 'utf-8' # The master toctree document. -master_doc = 'index' +master_doc = 'contents' # General information about the project. project = 'scikit-learn' @@ -86,7 +87,7 @@ # # The short X.Y version. import sklearn -version = sklearn.__version__ +version = parse(sklearn.__version__).base_version # The full version, including alpha/beta/rc tags. release = sklearn.__version__ @@ -130,14 +131,13 @@ # The theme to use for HTML and HTML Help pages. Major themes that come with # Sphinx are currently 'default' and 'sphinxdoc'. -html_theme = 'scikit-learn' +html_theme = 'scikit-learn-modern' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. 
-html_theme_options = {'oldversion': False, 'collapsiblesidebar': True, - 'google_analytics': True, 'surveybanner': False, - 'sprintbanner': True, 'body_max_width': None} +html_theme_options = {'google_analytics': True, + 'mathjax_path': mathjax_path} # Add any paths that contain custom themes here, relative to this directory. html_theme_path = ['themes'] @@ -173,7 +173,9 @@ # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +html_additional_pages = { + 'index': 'index.html', + 'documentation': 'documentation.html'} # redirects to index # If false, no module index is generated. html_domain_indices = False @@ -198,6 +200,8 @@ # Output file base name for HTML help builder. htmlhelp_basename = 'scikit-learndoc' +# If true, the reST sources are included in the HTML build as _sources/name. +html_copy_source = True # -- Options for LaTeX output ------------------------------------------------ latex_elements = { @@ -217,7 +221,7 @@ # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass # [howto/manual]). -latex_documents = [('index', 'user_guide.tex', 'scikit-learn user guide', +latex_documents = [('contents', 'user_guide.tex', 'scikit-learn user guide', 'scikit-learn developers', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -243,19 +247,48 @@ 'joblib': ('https://joblib.readthedocs.io/en/latest/', None), } -if 'dev' in version: +v = parse(release) +if v.release is None: + raise ValueError( + 'Ill-formed version: {!r}. Version should follow ' + 'PEP440'.format(version)) + +if v.is_devrelease: binder_branch = 'master' else: - match = re.match(r'^(\d+)\.(\d+)(?:\.\d+)?$', version) - if match is None: - raise ValueError( - 'Ill-formed version: {!r}. 
Expected either ' - "a version containing 'dev' " - 'or a version like X.Y or X.Y.Z.'.format(version)) - - major, minor = match.groups() + major, minor = v.release[:2] binder_branch = '{}.{}.X'.format(major, minor) + +class SubSectionTitleOrder: + """Sort example gallery by title of subsection. + + Assumes README.txt exists for all subsections and uses the subsection with + dashes, '---', as the adornment. + """ + def __init__(self, src_dir): + self.src_dir = src_dir + self.regex = re.compile(r"^([\w ]+)\n-", re.MULTILINE) + + def __repr__(self): + return '<%s>' % (self.__class__.__name__,) + + def __call__(self, directory): + src_path = os.path.normpath(os.path.join(self.src_dir, directory)) + readme = os.path.join(src_path, "README.txt") + + try: + with open(readme, 'r') as f: + content = f.read() + except FileNotFoundError: + return directory + + title_match = self.regex.search(content) + if title_match is not None: + return title_match.group(1) + return directory + + sphinx_gallery_conf = { 'doc_module': 'sklearn', 'backreferences_dir': os.path.join('modules', 'generated'), @@ -264,6 +297,7 @@ 'sklearn': None}, 'examples_dirs': ['../examples'], 'gallery_dirs': ['auto_examples'], + 'subsection_order': SubSectionTitleOrder('../examples'), 'binder': { 'org': 'scikit-learn', 'repo': 'scikit-learn', @@ -271,7 +305,9 @@ 'branch': binder_branch, 'dependencies': './binder/requirements.txt', 'use_jupyter_lab': True - } + }, + # avoid generating too many cross links + 'inspect_global_variables': False, } @@ -279,11 +315,7 @@ # thumbnails for the front page of the scikit-learn home page. 
# key: first image in set # values: (number of plot in set, height of thumbnail) -carousel_thumbs = {'sphx_glr_plot_classifier_comparison_001.png': 600, - 'sphx_glr_plot_anomaly_comparison_001.png': 372, - 'sphx_glr_plot_gpr_co2_001.png': 350, - 'sphx_glr_plot_adaboost_twoclass_001.png': 372, - 'sphx_glr_plot_compare_methods_001.png': 349} +carousel_thumbs = {'sphx_glr_plot_classifier_comparison_001.png': 600} # enable experimental module so that experimental estimators can be @@ -306,6 +338,27 @@ def make_carousel_thumbs(app, exception): sphinx_gallery.gen_rst.scale_image(image, c_thumb, max_width, 190) +def filter_search_index(app, exception): + if exception is not None: + return + + # searchindex only exist when generating html + if app.builder.name != 'html': + return + + print('Removing methods from search index') + + searchindex_path = os.path.join(app.builder.outdir, 'searchindex.js') + with open(searchindex_path, 'r') as f: + searchindex_text = f.read() + + searchindex_text = re.sub(r'{__init__.+?}', '{}', searchindex_text) + searchindex_text = re.sub(r'{__call__.+?}', '{}', searchindex_text) + + with open(searchindex_path, 'w') as f: + f.write(searchindex_text) + + # Config for sphinx_issues # we use the issues path for PRs since the issues URL will forward @@ -314,9 +367,8 @@ def make_carousel_thumbs(app, exception): def setup(app): # to hide/show the prompt in code examples: - app.add_javascript('js/copybutton.js') - app.add_javascript('js/extra.js') app.connect('build-finished', make_carousel_thumbs) + app.connect('build-finished', filter_search_index) # The following is used by sphinx.ext.linkcode to provide links to github diff --git a/doc/conftest.py b/doc/conftest.py index c66be1ef6deec..d1be865135e76 100644 --- a/doc/conftest.py +++ b/doc/conftest.py @@ -6,11 +6,11 @@ import numpy as np from sklearn.utils import IS_PYPY -from sklearn.utils.testing import SkipTest -from sklearn.utils.testing import check_skip_network +from sklearn.utils._testing 
import SkipTest +from sklearn.utils._testing import check_skip_network from sklearn.datasets import get_data_home -from sklearn.datasets.base import _pkl_filepath -from sklearn.datasets.twenty_newsgroups import CACHE_NAME +from sklearn.datasets._base import _pkl_filepath +from sklearn.datasets._twenty_newsgroups import CACHE_NAME def setup_labeled_faces(): diff --git a/doc/contents.rst b/doc/contents.rst new file mode 100644 index 0000000000000..a28634621d558 --- /dev/null +++ b/doc/contents.rst @@ -0,0 +1,24 @@ +.. include:: includes/big_toc_css.rst +.. include:: tune_toc.rst + +.. Places global toc into the sidebar + +:globalsidebartoc: True + +================= +Table Of Contents +================= + +.. Define an order for the Table of Contents: + +.. toctree:: + :maxdepth: 2 + + preface + tutorial/index + getting_started + user_guide + glossary + auto_examples/index + modules/classes + developers/index diff --git a/doc/data_transforms.rst b/doc/data_transforms.rst index 5b5c356324197..01547f68008b6 100644 --- a/doc/data_transforms.rst +++ b/doc/data_transforms.rst @@ -24,6 +24,7 @@ transformations of the target space (e.g. categorical labels) for use in scikit-learn. .. toctree:: + :maxdepth: 2 modules/compose modules/feature_extraction diff --git a/doc/developers/advanced_installation.rst b/doc/developers/advanced_installation.rst index 0eaac27699d37..8fd0f9ecf0273 100644 --- a/doc/developers/advanced_installation.rst +++ b/doc/developers/advanced_installation.rst @@ -1,31 +1,12 @@ .. _advanced-installation: -=================================== -Advanced installation instructions -=================================== +================================================== +Installing the development version of scikit-learn +================================================== -There are different ways to get scikit-learn installed: - - * :ref:`Install an official release `. This - is the best approach for most users. 
It will provide a stable version - and pre-build packages are available for most platforms. - - * Install the version of scikit-learn provided by your - :ref:`operating system or Python distribution `. - This is a quick option for those who have operating systems - that distribute scikit-learn. It might not provide the latest release - version. - - * :ref:`Building the package from source - `. This is best for users who want the - latest-and-greatest features and aren't afraid of running - brand-new code. This document describes how to build from source. - -.. note:: - - If you wish to contribute to the project, you need to - :ref:`install the latest development version`. +This section introduces how to install the **master branch** of scikit-learn. +This can be done by either installing a nightly build or building from source. .. _install_nightly_builds: @@ -34,7 +15,16 @@ Installing nightly builds The continuous integration servers of the scikit-learn project build, test and upload wheel packages for the most recent Python version on a nightly -basis to help users test bleeding edge features or bug fixes:: +basis. + +Installing a nightly build is the quickest way to: + +- try a new feature that will be shipped in the next release (that is, a + feature from a pull-request that was recently merged to the master branch); + +- check whether a bug you encountered has been fixed since the last release. + +:: pip install --pre -f https://sklearn-nightly.scdn8.secure.raxcdn.com scikit-learn @@ -42,250 +32,353 @@ basis to help users test bleeding edge features or bug fixes:: .. _install_bleeding_edge: Building from source -===================== +==================== + +Building from source is required to work on a contribution (bug fix, new +feature, code or documentation improvement). + +.. _git_repo: -In the vast majority of cases, building scikit-learn for development purposes -can be done with:: +#. 
Use `Git `_ to check out the latest source from the + `scikit-learn repository `_ on + Github.:: - pip install cython pytest flake8 + git clone git://github.com/scikit-learn/scikit-learn.git + cd scikit-learn -Then, in the main repository:: + If you plan on submitting a pull-request, you should clone from your fork + instead. - pip install --editable . +#. Install a compiler with OpenMP_ support for your platform. See instructions + for :ref:`compiler_windows`, :ref:`compiler_macos`, :ref:`compiler_linux` + and :ref:`compiler_freebsd`. -Please read below for details and more advanced instructions. +#. Optional (but recommended): create and activate a dedicated virtualenv_ + or `conda environment`_. + +#. Install Cython_ and build the project with pip in :ref:`editable_mode`:: + + pip install cython + pip install --verbose --editable . + +#. Check that the installed scikit-learn has a version number ending with + `.dev0`:: + + python -c "import sklearn; sklearn.show_versions()" + +#. Please refer to the :ref:`developers_guide` and :ref:`pytest_tips` to run + the tests on the module of your choice. + +.. note:: + + You will have to re-run the ``pip install --editable .`` command every time + the source code of a Cython file is updated (ending in `.pyx` or `.pxd`). Dependencies ------------ -Scikit-learn requires: +Runtime dependencies +~~~~~~~~~~~~~~~~~~~~ + +Scikit-learn requires the following dependencies both at build time and at +runtime: - Python (>= 3.5), - NumPy (>= 1.11), - SciPy (>= 0.17), - Joblib (>= 0.11). +Those dependencies are **automatically installed by pip** if they were missing +when building scikit-learn from source. + .. note:: - For installing on PyPy, PyPy3-v5.10+, Numpy 1.14.0+, and scipy 1.1.0+ + For running on PyPy, PyPy3-v5.10+, Numpy 1.14.0+, and scipy 1.1.0+ are required. For PyPy, only installation instructions with pip apply. 
+Build dependencies +~~~~~~~~~~~~~~~~~~ -Building Scikit-learn also requires +Building Scikit-learn also requires: -- Cython >=0.28.5 -- OpenMP +.. + # The following places need to be in sync with regard to Cython version: + # - .circleci config file + # - sklearn/_build_utils/__init__.py + # - advanced installation guide + +- Cython >= 0.28.5 +- A C/C++ compiler and a matching OpenMP_ runtime library. See the + :ref:`platform system specific instructions + ` for more details. .. note:: - It is possible to build scikit-learn without OpenMP support by setting the - ``SKLEARN_NO_OPENMP`` environment variable (before cythonization). This is - not recommended since it will force some estimators to run in sequential - mode and their ``n_jobs`` parameter will be ignored. + If OpenMP is not supported by the compiler, the build will be done with + OpenMP functionalities disabled. This is not recommended since it will force + some estimators to run in sequential mode instead of leveraging thread-based + parallelism. Setting the ``SKLEARN_FAIL_NO_OPENMP`` environment variable + (before cythonization) will force the build to fail if OpenMP is not + supported. + +Since version 0.21, scikit-learn automatically detects and uses the linear +algebra library used by SciPy **at runtime**. Scikit-learn has therefore no +build dependency on BLAS/LAPACK implementations such as OpenBlas, Atlas, Blis +or MKL. +Test dependencies +~~~~~~~~~~~~~~~~~ -Running tests requires +Running tests requires: -.. |PytestMinVersion| replace:: 3.3.0 +.. |PytestMinVersion| replace:: 4.6.2 - pytest >=\ |PytestMinVersion| Some tests also require `pandas `_. -.. _git_repo: -Retrieving the latest code --------------------------- - -We use `Git `_ for version control and -`GitHub `_ for hosting our main repository.
- -You can check out the latest sources with the command:: - - git clone git://github.com/scikit-learn/scikit-learn.git +Building a specific version from a tag +-------------------------------------- If you want to build a stable version, you can ``git checkout `` to get the code for that particular version, or download an zip archive of the version from github. -Once you have all the build requirements installed (see below for details), -you can build and install the package in the following way. +.. _editable_mode: -If you run the development version, it is cumbersome to reinstall the -package each time you update the sources. Therefore it's recommended that you -install in editable mode, which allows you to edit the code in-place. This -builds the extension in place and creates a link to the development directory -(see `the pip docs `_):: +Editable mode +------------- - pip install --editable . +If you run the development version, it is cumbersome to reinstall the package +each time you update the sources. Therefore it is recommended that you install +with the ``pip install --editable .`` command, which allows you to edit the +code in-place. This builds the extension in place and creates a link to the +development directory (see `the pip docs +`_). -.. note:: +This is fundamentally similar to using the command ``python setup.py develop`` +(see `the setuptools docs +`_). +It is however preferred to use pip. - This is fundamentally similar to using the command ``python setup.py develop`` - (see `the setuptool docs `_). - It is however preferred to use pip. +On Unix-like systems, you can equivalently type ``make in`` from the top-level +folder. Have a look at the ``Makefile`` for additional utilities. -.. note:: - - You will have to re-run:: +.. _platform_specific_instructions: - pip install --editable .
+Platform-specific instructions +============================== - every time the source code of a compiled extension is changed (for - instance when switching branches or pulling changes from upstream). - Compiled extensions are Cython files (ending in `.pyx` or `.pxd`). +Here are instructions to install a working C/C++ compiler with OpenMP support +to build scikit-learn Cython extensions for each supported platform. -On Unix-like systems, you can equivalently type ``make in`` from the -top-level folder. Have a look at the ``Makefile`` for additional utilities. +.. _compiler_windows: -Mac OSX +Windows ------- -The default C compiler, Apple-clang, on Mac OSX does not directly support -OpenMP. The first solution to build scikit-learn is to install another C -compiler such as gcc or llvm-clang. Another solution is to enable OpenMP -support on the default Apple-clang. In the following we present how to -configure this second option. +First, install `Build Tools for Visual Studio 2019 +`_. -You first need to install the OpenMP library:: +.. warning:: - brew install libomp + You DO NOT need to install Visual Studio 2019. You only need the "Build + Tools for Visual Studio 2019", under "All downloads" -> "Tools for Visual + Studio 2019". -Then you need to set the following environment variables:: +Secondly, find out if you are running 64-bit or 32-bit Python. The building +command depends on the architecture of the Python interpreter. 
You can check +the architecture by running the following in ``cmd`` or ``powershell`` +console:: - export CC=/usr/bin/clang - export CXX=/usr/bin/clang++ - export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" - export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" - export LDFLAGS="$LDFLAGS -L/usr/local/opt/libomp/lib -lomp" - export DYLD_LIBRARY_PATH=/usr/local/opt/libomp/lib + python -c "import struct; print(struct.calcsize('P') * 8)" -Finally you can build the package using the standard command. +For 64-bit Python, configure the build environment with:: -FreeBSD -------- + SET DISTUTILS_USE_SDK=1 + "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 -The clang compiler included in FreeBSD 12.0 and 11.2 base systems does not -include OpenMP support. You need to install the `openmp` library from packages -(or ports):: +Replace ``x64`` by ``x86`` to build for 32-bit Python. - sudo pkg install openmp - -This will install header files in ``/usr/local/include`` and libs in -``/usr/local/lib``. Since these directories are not searched by default, you -can set the environment variables to these locations:: +Please be aware that the path above might be different from user to user. The +aim is to point to the "vcvarsall.bat" file that will set the necessary +environment variables in the current command prompt. - export CFLAGS="$CFLAGS -I/usr/local/include" - export CXXFLAGS="$CXXFLAGS -I/usr/local/include" - export LDFLAGS="$LDFLAGS -L/usr/local/lib -lomp" - export DYLD_LIBRARY_PATH=/usr/local/lib +Finally, build scikit-learn from this command prompt:: -Finally you can build the package using the standard command. + pip install --verbose --editable . -For the upcomming FreeBSD 12.1 and 11.3 versions, OpenMP will be included in -the base system and these steps will not be necessary. +.. 
_compiler_macos: +macOS +----- -Installing build dependencies -============================= +The default C compiler on macOS, Apple clang (confusingly aliased as +`/usr/bin/gcc`), does not directly support OpenMP. We present two alternatives +to enable OpenMP support: -Linux ------ +- either install `conda-forge::compilers` with conda; + +- or install `libomp` with Homebrew to extend the default Apple clang compiler. -Installing from source without conda requires you to have installed the -scikit-learn runtime dependencies, Python development headers and a working -C/C++ compiler. Under Debian-based operating systems, which include Ubuntu:: +macOS compilers from conda-forge +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - sudo apt-get install build-essential python3-dev python3-setuptools \ - python3-pip - -and then:: +If you use the conda package manager (version >= 4.7), you can install the +``compilers`` meta-package from the conda-forge channel, which provides +OpenMP-enabled C/C++ compilers based on the llvm toolchain. - pip3 install numpy scipy cython +First install the macOS command line tools:: + + xcode-select --install + +It is recommended to use a dedicated `conda environment`_ to build +scikit-learn from source:: + + conda create -n sklearn-dev python numpy scipy cython joblib pytest \ + conda-forge::compilers conda-forge::llvm-openmp + conda activate sklearn-dev + make clean + pip install --verbose --editable . .. note:: - In order to build the documentation and run the example code contains in - this documentation you will need matplotlib:: + If you get any conflicting dependency error message, try commenting out + any custom conda configuration in the ``$HOME/.condarc`` file. In + particular the ``channel_priority: strict`` directive is known to cause + problems for this setup. 
- pip3 install matplotlib +You can check that the custom compilers are properly installed from conda +forge using the following command:: -When precompiled wheels are not avalaible for your architecture, you can -install the system versions:: + conda list compilers llvm-openmp - sudo apt-get install cython3 python3-numpy python3-scipy python3-matplotlib +The compilers meta-package will automatically set custom environment +variables:: -On Red Hat and clones (e.g. CentOS), install the dependencies using:: + echo $CC + echo $CXX + echo $CFLAGS + echo $CXXFLAGS + echo $LDFLAGS - sudo yum -y install gcc gcc-c++ python-devel numpy scipy +They point to files and folders from your ``sklearn-dev`` conda environment +(in particular in the bin/, include/ and lib/ subfolders). For instance +``-L/path/to/conda/envs/sklearn-dev/lib`` should appear in ``LDFLAGS``. -.. note:: +In the log, you should see the compiled extension being built with the clang +and clang++ compilers installed by conda with the ``-fopenmp`` command line +flag. - To use a high performance BLAS library (e.g. OpenBlas) see - `scipy installation instructions - `_. +macOS compilers from Homebrew +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Windows -------- +Another solution is to enable OpenMP support for the clang compiler shipped +by default on macOS. -To build scikit-learn on Windows you need a working C/C++ compiler in -addition to numpy, scipy and setuptools. +First install the macOS command line tools:: -The building command depends on the architecture of the Python interpreter, -32-bit or 64-bit. You can check the architecture by running the following in -``cmd`` or ``powershell`` console:: + xcode-select --install - python -c "import struct; print(struct.calcsize('P') * 8)" +Install the Homebrew_ package manager for macOS. -The above commands assume that you have the Python installation folder in your -PATH environment variable. 
+Install the LLVM OpenMP library:: -You will need `Build Tools for Visual Studio 2017 -`_. + brew install libomp -.. warning:: - You DO NOT need to install Visual Studio 2019. - You only need the "Build Tools for Visual Studio 2019", - under "All downloads" -> "Tools for Visual Studio 2019". +Set the following environment variables:: -For 64-bit Python, configure the build environment with:: + export CC=/usr/bin/clang + export CXX=/usr/bin/clang++ + export CPPFLAGS="$CPPFLAGS -Xpreprocessor -fopenmp" + export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include" + export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include" + export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp" - SET DISTUTILS_USE_SDK=1 - "C:\Program Files (x86)\Microsoft Visual Studio\2019\BuildTools\VC\Auxiliary\Build\vcvarsall.bat" x64 +Finally, build scikit-learn in verbose mode (to check for the presence of the +``-fopenmp`` flag in the compiler commands):: -Please be aware that the path above might be different from user to user. -The aim is to point to the "vcvarsall.bat" file. + make clean + pip install --verbose --editable . -And build scikit-learn from this environment:: +.. _compiler_linux: - python setup.py install +Linux +----- -Replace ``x64`` by ``x86`` to build for 32-bit Python. +Linux compilers from the system +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Installing scikit-learn from source without using conda requires you to have +installed the scikit-learn Python development headers and a working C/C++ +compiler with OpenMP support (typically the GCC toolchain). -Building binary packages and installers ---------------------------------------- +Install build dependencies for Debian-based operating systems, e.g. 
+Ubuntu:: -The ``.whl`` package and ``.exe`` installers can be built with:: + sudo apt-get install build-essential python3-dev python3-pip - pip install wheel - python setup.py bdist_wheel bdist_wininst -b doc/logos/scikit-learn-logo.bmp +then proceed as usual:: -The resulting packages are generated in the ``dist/`` folder. + pip3 install cython + pip3 install --verbose --editable . +Cython and the pre-compiled wheels for the runtime dependencies (numpy, scipy +and joblib) should automatically be installed in +``$HOME/.local/lib/pythonX.Y/site-packages``. Alternatively you can run the +above commands from a virtualenv_ or a `conda environment`_ to get full +isolation from the Python packages installed via the system packager. When +using an isolated environment, ``pip3`` should be replaced by ``pip`` in the +above commands. -Using an alternative compiler ------------------------------ +When precompiled wheels of the runtime dependencies are not available for your +architecture (e.g. ARM), you can install the system versions:: + + sudo apt-get install cython3 python3-numpy python3-scipy + +On Red Hat and clones (e.g. CentOS), install the dependencies using:: -It is possible to use `MinGW `_ (a port of GCC to Windows -OS) as an alternative to MSVC for 32-bit Python. Not that extensions built with -mingw32 can be redistributed as reusable packages as they depend on GCC runtime -libraries typically not installed on end-users environment. + sudo yum -y install gcc gcc-c++ python3-devel numpy scipy -To force the use of a particular compiler, pass the ``--compiler`` flag to the -build step:: +Linux compilers from conda-forge +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Alternatively, install a recent version of the GNU C Compiler toolchain (GCC) +in the user folder using conda:: + + conda create -n sklearn-dev numpy scipy joblib cython conda-forge::compilers + conda activate sklearn-dev + pip install --verbose --editable . + +..
_compiler_freebsd: + +FreeBSD +------- - python setup.py build --compiler=my_compiler install +The clang compiler included in FreeBSD 12.0 and 11.2 base systems does not +include OpenMP support. You need to install the `openmp` library from packages +(or ports):: + + sudo pkg install openmp + +This will install header files in ``/usr/local/include`` and libs in +``/usr/local/lib``. Since these directories are not searched by default, you +can set the environment variables to these locations:: + + export CFLAGS="$CFLAGS -I/usr/local/include" + export CXXFLAGS="$CXXFLAGS -I/usr/local/include" + export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/lib -L/usr/local/lib -lomp" + +Finally, build the package using the standard command:: + + pip install --verbose --editable . + +For the upcoming FreeBSD 12.1 and 11.3 versions, OpenMP will be included in +the base system and these steps will not be necessary. -where ``my_compiler`` should be one of ``mingw32`` or ``msvc``. +.. _OpenMP: https://en.wikipedia.org/wiki/OpenMP +.. _Cython: https://cython.org +.. _Homebrew: https://brew.sh +.. _virtualenv: https://docs.python.org/3/tutorial/venv.html +.. _conda environment: https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html diff --git a/doc/developers/contributing.rst b/doc/developers/contributing.rst index 4b24c7089a5a8..16adf4a607d90 100644 --- a/doc/developers/contributing.rst +++ b/doc/developers/contributing.rst @@ -210,7 +210,7 @@ how to set up your git repository: 4. Install the development dependencies:: - $ pip install cython pytest flake8 + $ pip install cython pytest pytest-cov flake8 5. Install scikit-learn in editable mode:: @@ -251,7 +251,7 @@ modifying code and submitting a PR: to record your changes in Git, then push the changes to your GitHub account with:: - $ git push -u origin my-feature + $ git push -u origin my_feature 10. 
Follow `these `_ @@ -281,6 +281,8 @@ line http://try.github.io are excellent resources to get started with git, and understanding all of the commands shown here. +.. _pr_checklist: + Pull request checklist ---------------------- @@ -319,15 +321,15 @@ complies with the following rules before marking a PR as ``[MRG]``. The specific to the file - `pytest sklearn/linear_model` to test the whole :mod:`~sklearn.linear_model` module - - `pytest sklearn/doc/linear_model.rst` to make sure the user guide + - `pytest doc/modules/linear_model.rst` to make sure the user guide examples are correct. - `pytest sklearn/tests/test_common.py -k LogisticRegression` to run all our estimator checks (specifically for `LogisticRegression`, if that's the estimator you changed). There may be other failing tests, but they will be caught by the CI so - you don't need to run the whole test suite locally. You can read more in - :ref:`testing_coverage`. + you don't need to run the whole test suite locally. For guidelines on how + to use ``pytest`` efficiently, see the :ref:`pytest_tips`. 3. **Make sure your code is properly commented and documented**, and **make sure the documentation renders properly**. To build the documentation, please @@ -375,7 +377,7 @@ complies with the following rules before marking a PR as ``[MRG]``. The methods available in scikit-learn. 10. New features often need to be illustrated with narrative documentation in - the user guide, with small code snipets. If relevant, please also add + the user guide, with small code snippets. If relevant, please also add references in the literature, with PDF links when possible. 11. The user guide should also include expected time and space complexity @@ -435,6 +437,7 @@ message, the following actions are taken. ---------------------- ------------------- [scipy-dev] Add a Travis build with our dependencies (numpy, scipy, etc ...) 
development builds [ci skip] CI is skipped completely + [lint skip] Azure pipeline skips linting [doc skip] Docs are not built [doc quick] Docs built, but excludes example gallery plots [doc build] Docs built including example gallery plots @@ -535,9 +538,12 @@ Building the documentation First, make sure you have :ref:`properly installed ` the development version. +.. + packaging is not needed once setuptools starts shipping packaging>=17.0 + Building the documentation requires installing some additional packages:: - pip install sphinx sphinx-gallery numpydoc matplotlib Pillow pandas scikit-image + pip install sphinx sphinx-gallery numpydoc matplotlib Pillow pandas scikit-image packaging To build the documentation, you need to be in the ``doc`` folder:: @@ -700,14 +706,12 @@ package. The tests are functions appropriately named, located in `tests` subdirectories, that check the validity of the algorithms and the different options of the code. -The full scikit-learn tests can be run using 'make' in the root folder. -Alternatively, running 'pytest' in a folder will run all the tests of -the corresponding subpackages. +Running `pytest` in a folder will run all the tests of the corresponding +subpackages. For a more detailed `pytest` workflow, please refer to the +:ref:`pr_checklist`. We expect code coverage of new features to be at least around 90%. -For guidelines on how to use ``pytest`` efficiently, see the -:ref:`pytest_tips`. Writing matplotlib related tests -------------------------------- @@ -826,7 +830,8 @@ E.g., renaming an attribute ``labels_`` to ``classes_`` can be done as:: def labels_(self): return self.classes_ -If a parameter has to be deprecated, use ``DeprecationWarning`` appropriately. +If a parameter has to be deprecated, a ``FutureWarning`` warning +must be raised too. 
In the following example, k is deprecated and renamed to n_clusters:: import warnings @@ -834,7 +839,8 @@ In the following example, k is deprecated and renamed to n_clusters:: def example_function(n_clusters=8, k='deprecated'): if k != 'deprecated': warnings.warn("'k' was renamed to n_clusters in version 0.13 and " - "will be removed in 0.15.", DeprecationWarning) + "will be removed in 0.15.", + FutureWarning) n_clusters = k When the change is in a class, we validate and raise warning in ``fit``:: @@ -849,7 +855,8 @@ When the change is in a class, we validate and raise warning in ``fit``:: def fit(self, X, y): if self.k != 'deprecated': warnings.warn("'k' was renamed to n_clusters in version 0.13 and " - "will be removed in 0.15.", DeprecationWarning) + "will be removed in 0.15.", + FutureWarning) self._n_clusters = self.k else: self._n_clusters = self.n_clusters @@ -1037,53 +1044,3 @@ make this task easier and faster (in no particular order). `_) is also extremely useful to see every occurrence of a pattern (e.g. a function call or a variable) in the code base. - - -.. _plotting_api: - -Plotting API -============ - -Scikit-learn defines a simple API for creating visualizations for machine -learning. The key features of this API is to run calculations once and to have -the flexibility to adjust the visualizations after the fact. This logic is -encapsulated into a display object where the computed data is stored and -the plotting is done in a `plot` method. The display object's `__init__` -method contains only the data needed to create the visualization. The `plot` -method takes in parameters that only have to do with visualization, such as a -matplotlib axes. The `plot` method will store the matplotlib artists as -attributes allowing for style adjustments through the display object. A -`plot_*` helper function accepts parameters to do the computation and the -parameters used for plotting. 
After the helper function creates the display -object with the computed values, it calls the display's plot method. Note -that the `plot` method defines attributes related to matplotlib, such as the -line artist. This allows for customizations after calling the `plot` method. - -For example, the `RocCurveDisplay` defines the following methods and -attributes: - -.. code-block:: python - - class RocCurveDisplay: - def __init__(self, fpr, tpr, roc_auc, estimator_name): - ... - self.fpr = fpr - self.tpr = tpr - self.roc_auc = roc_auc - self.estimator_name = estimator_name - - def plot(self, ax=None, name=None, **kwargs): - ... - self.line_ = ... - self.ax_ = ax - self.figure_ = ax.figure_ - - def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, - drop_intermediate=True, response_method="auto", - name=None, ax=None, **kwargs): - # do computation - viz = RocCurveDisplay(fpr, tpr, roc_auc, - estimator.__class__.__name__) - return viz.plot(ax=ax, name=name, **kwargs) - -Read more in the :ref:`User Guide `. diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 7fd76b23f4f28..ead6286d98083 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -453,7 +453,7 @@ this can be achieved with:: return self.classes_[np.argmax(D, axis=1)] In linear models, coefficients are stored in an array called ``coef_``, and the -independent term is stored in ``intercept_``. ``sklearn.linear_model.base`` +independent term is stored in ``intercept_``. ``sklearn.linear_model._base`` contains a few base classes and mixins that implement common linear model patterns. diff --git a/doc/developers/index.rst b/doc/developers/index.rst index f1e4816855180..e64adf5ac73a9 100644 --- a/doc/developers/index.rst +++ b/doc/developers/index.rst @@ -1,3 +1,7 @@ +.. Places global toc into the sidebar + +:globalsidebartoc: True + .. 
_developers_guide: ================= @@ -16,3 +20,4 @@ Developer's Guide performance advanced_installation maintainer + plotting diff --git a/doc/developers/maintainer.rst b/doc/developers/maintainer.rst index e91f01999b12e..66d5250af1644 100644 --- a/doc/developers/maintainer.rst +++ b/doc/developers/maintainer.rst @@ -62,7 +62,7 @@ Making a release 2. On the branch for releasing, update the version number in sklearn/__init__.py, the ``__version__`` variable by removing ``dev*`` only when ready to release. - On master, increment the verson in the same place (when branching for + On master, increment the version in the same place (when branching for release). 3. Create the tag and push it:: diff --git a/doc/developers/performance.rst b/doc/developers/performance.rst index 743835d41375c..1be0dc9b575e1 100644 --- a/doc/developers/performance.rst +++ b/doc/developers/performance.rst @@ -200,7 +200,8 @@ Now restart IPython and let us use this new toy:: In [1]: from sklearn.datasets import load_digits - In [2]: from sklearn.decomposition.nmf import _nls_subproblem, NMF + In [2]: from sklearn.decomposition import NMF + ... : from sklearn.decomposition._nmf import _nls_subproblem In [3]: X, _ = load_digits(return_X_y=True) @@ -331,16 +332,16 @@ memory alignment, direct blas calls... Using OpenMP ------------ -Since scikit-learn can be built without OpenMP support, it's necessary to +Since scikit-learn can be built without OpenMP, it's necessary to protect each direct call to OpenMP. 
This can be done using the following syntax:: # importing OpenMP - IF SKLEARN_OPENMP_SUPPORTED: + IF SKLEARN_OPENMP_PARALLELISM_ENABLED: cimport openmp # calling OpenMP - IF SKLEARN_OPENMP_SUPPORTED: + IF SKLEARN_OPENMP_PARALLELISM_ENABLED: max_threads = openmp.omp_get_max_threads() ELSE: max_threads = 1 diff --git a/doc/developers/plotting.rst b/doc/developers/plotting.rst new file mode 100644 index 0000000000000..98af195b56453 --- /dev/null +++ b/doc/developers/plotting.rst @@ -0,0 +1,90 @@ +.. _plotting_api: + +================================ +Developing with the Plotting API +================================ + +Scikit-learn defines a simple API for creating visualizations for machine +learning. The key features of this API are to run calculations once and to have +the flexibility to adjust the visualizations after the fact. This section is +intended for developers who wish to develop or maintain plotting tools. For +usage, users should refer to the :ref:`User Guide `. + +Plotting API Overview +--------------------- + +This logic is encapsulated into a display object where the computed data is +stored and the plotting is done in a `plot` method. The display object's +`__init__` method contains only the data needed to create the visualization. +The `plot` method takes in parameters that only have to do with visualization, +such as a matplotlib axes. The `plot` method will store the matplotlib artists +as attributes allowing for style adjustments through the display object. A +`plot_*` helper function accepts parameters to do the computation and the +parameters used for plotting. After the helper function creates the display +object with the computed values, it calls the display's plot method. Note that +the `plot` method defines attributes related to matplotlib, such as the line +artist. This allows for customizations after calling the `plot` method. 
+ +For example, the `RocCurveDisplay` defines the following methods and +attributes:: + + class RocCurveDisplay: + def __init__(self, fpr, tpr, roc_auc, estimator_name): + ... + self.fpr = fpr + self.tpr = tpr + self.roc_auc = roc_auc + self.estimator_name = estimator_name + + def plot(self, ax=None, name=None, **kwargs): + ... + self.line_ = ... + self.ax_ = ax + self.figure_ = ax.figure_ + + def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, + drop_intermediate=True, response_method="auto", + name=None, ax=None, **kwargs): + # do computation + viz = RocCurveDisplay(fpr, tpr, roc_auc, + estimator.__class__.__name__) + return viz.plot(ax=ax, name=name, **kwargs) + +Read more in :ref:`sphx_glr_auto_examples_plot_roc_curve_visualization_api.py` +and the :ref:`User Guide `. + +Plotting with Multiple Axes +--------------------------- + +Some of the plotting tools like +:func:`~sklearn.inspection.plot_partial_dependence` and +:class:`~sklearn.inspection.PartialDependenceDisplay` support plotting on +multiple axes. Two different scenarios are supported: + +1. If a list of axes is passed in, `plot` will check if the number of axes is +consistent with the number of axes it expects and then draws on those axes. 2. +If a single axes is passed in, that axes defines a space for multiple axes to +be placed. In this case, we suggest using matplotlib's +`~matplotlib.gridspec.GridSpecFromSubplotSpec` to split up the space:: + + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpecFromSubplotSpec + + fig, ax = plt.subplots() + gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec()) + + ax_top_left = fig.add_subplot(gs[0, 0]) + ax_top_right = fig.add_subplot(gs[0, 1]) + ax_bottom = fig.add_subplot(gs[1, :]) + +By default, the `ax` keyword in `plot` is `None`. In this case, the single +axes is created and the gridspec api is used to create the regions to plot in. 
+ +See for example, :func:`~sklearn.inspection.plot_partial_dependence` which +plots multiple lines and contours using this API. The axes defining the +bounding box is saved in a `bounding_ax_` attribute. The individual axes +created are stored in an `axes_` ndarray, corresponding to the axes position on +the grid. Positions that are not used are set to `None`. Furthermore, the +matplotlib Artists are stored in `lines_` and `contours_` where the key is the +position on the grid. When a list of axes is passed in, the `axes_`, `lines_`, +and `contours_` are 1d ndarrays corresponding to the list of axes passed in. diff --git a/doc/developers/tips.rst b/doc/developers/tips.rst index 76b655274ef28..ed049285c36f3 100644 --- a/doc/developers/tips.rst +++ b/doc/developers/tips.rst @@ -102,10 +102,10 @@ Other `pytest` options that may become useful include: statements - ``--tb=short`` or ``--tb=line`` to control the length of the logs -Since our continuous integration tests will error if ``DeprecationWarning`` -or ``FutureWarning`` aren't properly caught, it is also recommended to run -``pytest`` along with the ``-Werror::DeprecationWarning`` and -``-Werror::FutureWarning`` flags. +Since our continuous integration tests will error if +``FutureWarning`` isn't properly caught, +it is also recommended to run ``pytest`` along with the +``-Werror::FutureWarning`` flag. .. _saved_replies: @@ -181,10 +181,10 @@ Issue/Comment: Linking to comments Please use links to comments, which make it a lot easier to see what you are referring to, rather than just linking to the issue. See [this](https://stackoverflow.com/questions/25163598/how-do-i-reference-a-specific-issue-comment-on-github) for more details. -PR-NEW: Better description +PR-NEW: Better description and title :: - Thanks for the pull request! Please make the title of the PR descriptive so that we can easily recall the issue it is resolving. 
You should state what issue (or PR) it fixes/resolves in the description (see [here](http://scikit-learn.org/dev/developers/contributing.html#contributing-pull-requests)). + Thanks for the pull request! Please make the title of the PR more descriptive. The title will become the commit message when this is merged. You should state what issue (or PR) it fixes/resolves in the description using the syntax described [here](http://scikit-learn.org/dev/developers/contributing.html#contributing-pull-requests). PR-NEW: Fix # :: diff --git a/doc/developers/utilities.rst b/doc/developers/utilities.rst index 83fd044b99df3..3d4995d8f8100 100644 --- a/doc/developers/utilities.rst +++ b/doc/developers/utilities.rst @@ -176,13 +176,7 @@ Graph Routines Testing Functions ================= -- :func:`testing.assert_in`, :func:`testing.assert_not_in`: Assertions for - container membership. Designed for forward compatibility with Nose 1.0. - -- :func:`testing.assert_raise_message`: Assertions for checking the - error raise message. - -- :func:`testing.all_estimators` : returns a list of all estimators in +- :func:`all_estimators` : returns a list of all estimators in scikit-learn to test for consistent behavior and interfaces. Multiclass and multilabel utility function diff --git a/doc/documentation.rst b/doc/documentation.rst deleted file mode 100644 index a55fbe37258ae..0000000000000 --- a/doc/documentation.rst +++ /dev/null @@ -1,117 +0,0 @@ -:orphan: - -.. raw:: html - -
- -Documentation of scikit-learn |version| -======================================= - -.. raw:: html - - - -
-
-

Quick Start

-
A very short introduction into machine learning - problems and how to solve them using scikit-learn. - Presents basic concepts and conventions. -
-
-
-

User Guide

-
The main documentation. This contains an - in-depth description of all algorithms and how - to apply them. -
-
-
- -

Other Versions

-
    - - -
  • All available versions
  • -
  • PDF documentation
  • -
- -
- -
- - -
-
-

Tutorials

-
Useful tutorials for developing a feel - for some of scikit-learn's applications in the - machine learning field. -
-
-
-

Glossary

-
The definitive description of key concepts - and API elements for using scikit-learn and developing compatible tools. -
-
-
-

API

-
The exact API of all functions and classes, as given by the docstrings. - The API documents expected types and allowed features for all functions, - and all parameters available for the algorithms. -
-
- -
- -
-
-

Development

-
Information on how to contribute. This also - contains useful information for advanced users, for example - how to build their own estimators. -
-
-
-

FAQ

-
Frequently asked questions about the project and contributing. -
-
-
-

Additional Resources

-
Talks given, slide-sets and other information relevant to scikit-learn. -
-
- -
- -
-
-

Flow Chart

-
A graphical overview of basic areas of machine - learning, and guidance which kind of algorithms - to use in a given situation. -
-
-
-

Related packages

-
Other machine learning packages for Python and - related projects. Also algorithms that are slightly out of - scope or not well established enough for scikit-learn. -
-
-
-

Roadmap

-
Roadmap of the project. -
-
-
-
-
-

About us

-
About the scikit-learn project. -
-
- -
diff --git a/doc/faq.rst b/doc/faq.rst index 1ff092a1ee724..6972d79fd5513 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -299,23 +299,20 @@ documentation `. Why is there no support for deep or reinforcement learning / Will there be support for deep or reinforcement learning in scikit-learn? @@ -388,3 +385,23 @@ efficient to process for most operations. Extensive work would also be needed to support Pandas categorical types. Restricting input to homogeneous types therefore reduces maintenance cost and encourages usage of efficient data structures. + +Do you plan to implement transform for target y in a pipeline? +---------------------------------------------------------------------------- +Currently transform only works for features X in a pipeline. +There's a long-standing discussion about +not being able to transform y in a pipeline. +Follow on github issue +`#4143 `_. +Meanwhile check out +:class:`sklearn.compose.TransformedTargetRegressor`, +`pipegraph `_, +`imbalanced-learn `_. +Note that Scikit-learn solved for the case where y +has an invertible transformation applied before training +and inverted after prediction. Scikit-learn intends to solve for +use cases where y should be transformed at training time +and not at test time, for resampling and similar uses, +like at imbalanced learn. +In general, these use cases can be solved +with a custom meta estimator rather than a Pipeline diff --git a/doc/getting_started.rst b/doc/getting_started.rst new file mode 100644 index 0000000000000..ba18b92e40983 --- /dev/null +++ b/doc/getting_started.rst @@ -0,0 +1,231 @@ +Getting Started +=============== + +The purpose of this guide is to illustrate some of the main features that +``scikit-learn`` provides. It assumes a very basic working knowledge of +machine learning practices (model fitting, predicting, cross-validation, +etc.). Please refer to our :ref:`installation instructions +` for installing ``scikit-learn``. 
+ +``Scikit-learn`` is an open source machine learning library that supports +supervised and unsupervised learning. It also provides various tools for +model fitting, data preprocessing, model selection and evaluation, and many +other utilities. + +Fitting and predicting: estimator basics +---------------------------------------- + +``Scikit-learn`` provides dozens of built-in machine learning algorithms and +models, called :term:`estimators`. Each estimator can be fitted to some data +using its :term:`fit` method. + +Here is a simple example where we fit a +:class:`~sklearn.ensemble.RandomForestClassifier` to some very basic data:: + + >>> from sklearn.ensemble import RandomForestClassifier + >>> clf = RandomForestClassifier(random_state=0) + >>> X = [[ 1, 2, 3], # 2 samples, 3 features + ... [11, 12, 13]] + >>> y = [0, 1] # classes of each sample + >>> clf.fit(X, y) + RandomForestClassifier(random_state=0) + +The :term:`fit` method generally accepts 2 inputs: + +- The samples matrix (or design matrix) :term:`X`. The size of ``X`` + is typically ``(n_samples, n_features)``, which means that samples are + represented as rows and features are represented as columns. +- The target values :term:`y` which are real numbers for regression tasks, or + integers for classification (or any other discrete set of values). For + unsupervised learning tasks, ``y`` does not need to be specified. ``y`` is + usually a 1d array where the ``i`` th entry corresponds to the target of the + ``i`` th sample (row) of ``X``. + +Both ``X`` and ``y`` are usually expected to be numpy arrays or equivalent +:term:`array-like` data types, though some estimators work with other +formats such as sparse matrices. + +Once the estimator is fitted, it can be used for predicting target values of +new data. 
You don't need to re-train the estimator:: + + >>> clf.predict(X) # predict classes of the training data + array([0, 1]) + >>> clf.predict([[4, 5, 6], [14, 15, 16]]) # predict classes of new data + array([0, 1]) + +Transformers and pre-processors +------------------------------- + +Machine learning workflows are often composed of different parts. A typical +pipeline consists of a pre-processing step that transforms or imputes the +data, and a final predictor that predicts target values. + +In ``scikit-learn``, pre-processors and transformers follow the same API as +the estimator objects (they actually all inherit from the same +``BaseEstimator`` class). The transformer objects don't have a +:term:`predict` method but rather a :term:`transform` method that outputs a +newly transformed sample matrix ``X``:: + + >>> from sklearn.preprocessing import StandardScaler + >>> X = [[0, 15], + ... [1, -10]] + >>> StandardScaler().fit(X).transform(X) + array([[-1., 1.], + [ 1., -1.]]) + +Sometimes, you want to apply different transformations to different features: +the :ref:`ColumnTransformer` is designed for these +use-cases. + +Pipelines: chaining pre-processors and estimators +-------------------------------------------------- + +Transformers and estimators (predictors) can be combined together into a +single unifying object: a :class:`~sklearn.pipeline.Pipeline`. The pipeline +offers the same API as a regular estimator: it can be fitted and used for +prediction with ``fit`` and ``predict``. As we will see later, using a +pipeline will also prevent you from data leakage, i.e. disclosing some +testing data in your training data. 
+ +In the following example, we :ref:`load the Iris dataset `, split it +into train and test sets, and compute the accuracy score of a pipeline on +the test data:: + + >>> from sklearn.preprocessing import StandardScaler + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.pipeline import make_pipeline + >>> from sklearn.datasets import load_iris + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.metrics import accuracy_score + ... + >>> # create a pipeline object + >>> pipe = make_pipeline( + ... StandardScaler(), + ... LogisticRegression(random_state=0) + ... ) + ... + >>> # load the iris dataset and split it into train and test sets + >>> X, y = load_iris(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + ... + >>> # fit the whole pipeline + >>> pipe.fit(X_train, y_train) + Pipeline(steps=[('standardscaler', StandardScaler()), + ('logisticregression', LogisticRegression(random_state=0))]) + >>> # we can now use it like any other estimator + >>> accuracy_score(pipe.predict(X_test), y_test) + 0.97... + +Model evaluation +---------------- + +Fitting a model to some data does not entail that it will predict well on +unseen data. This needs to be directly evaluated. We have just seen the +:func:`~sklearn.model_selection.train_test_split` helper that splits a +dataset into train and test sets, but ``scikit-learn`` provides many other +tools for model evaluation, in particular for :ref:`cross-validation +`. + +We here briefly show how to perform a 5-fold cross-validation procedure, +using the :func:`~sklearn.model_selection.cross_validate` helper. Note that +it is also possible to manually iterate over the folds, use different +data splitting strategies, and use custom scoring functions. 
Please refer to +our :ref:`User Guide ` for more details:: + + >>> from sklearn.datasets import make_regression + >>> from sklearn.linear_model import LinearRegression + >>> from sklearn.model_selection import cross_validate + ... + >>> X, y = make_regression(n_samples=1000, random_state=0) + >>> lr = LinearRegression() + ... + >>> result = cross_validate(lr, X, y) # defaults to 5-fold CV + >>> result['test_score'] # r_squared score is high because dataset is easy + array([1., 1., 1., 1., 1.]) + +Automatic parameter searches +---------------------------- + +All estimators have parameters (often called hyper-parameters in the +literature) that can be tuned. The generalization power of an estimator +often critically depends on a few parameters. For example a +:class:`~sklearn.ensemble.RandomForestRegressor` has a ``n_estimators`` +parameter that determines the number of trees in the forest, and a +``max_depth`` parameter that determines the maximum depth of each tree. +Quite often, it is not clear what the exact values of these parameters +should be since they depend on the data at hand. + +``Scikit-learn`` provides tools to automatically find the best parameter +combinations (via cross-validation). In the following example, we randomly +search over the parameter space of a random forest with a +:class:`~sklearn.model_selection.RandomizedSearchCV` object. When the search +is over, the :class:`~sklearn.model_selection.RandomizedSearchCV` behaves as +a :class:`~sklearn.ensemble.RandomForestRegressor` that has been fitted with +the best set of parameters. Read more in the :ref:`User Guide +`:: + + >>> from sklearn.datasets import fetch_california_housing + >>> from sklearn.ensemble import RandomForestRegressor + >>> from sklearn.model_selection import RandomizedSearchCV + >>> from sklearn.model_selection import train_test_split + >>> from scipy.stats import randint + ... 
+ + >>> X, y = fetch_california_housing(return_X_y=True) + >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + ... + >>> # define the parameter space that will be searched over + >>> param_distributions = {'n_estimators': randint(1, 5), + ... 'max_depth': randint(5, 10)} + ... + >>> # now create a searchCV object and fit it to the data + >>> search = RandomizedSearchCV(estimator=RandomForestRegressor(random_state=0), + ... n_iter=5, + ... param_distributions=param_distributions, + ... random_state=0) + >>> search.fit(X_train, y_train) + RandomizedSearchCV(estimator=RandomForestRegressor(random_state=0), n_iter=5, + param_distributions={'max_depth': ..., + 'n_estimators': ...}, + random_state=0) + >>> search.best_params_ + {'max_depth': 9, 'n_estimators': 4} + + >>> # the search object now acts like a normal random forest estimator + >>> # with max_depth=9 and n_estimators=4 + >>> search.score(X_test, y_test) + 0.73... + +.. note:: + + In practice, you almost always want to :ref:`search over a pipeline + `, instead of a single estimator. One of the main + reasons is that if you apply a pre-processing step to the whole dataset + without using a pipeline, and then perform any kind of cross-validation, + you would be breaking the fundamental assumption of independence between + training and testing data. Indeed, since you pre-processed the data + using the whole dataset, some information about the test sets is + available to the train sets. This will lead to over-estimating the + generalization power of the estimator (you can read more in this `kaggle + post `_). + + Using a pipeline for cross-validation and searching will largely keep + you from this common pitfall. + + +Next steps +---------- + +We have briefly covered estimator fitting and predicting, pre-processing +steps, pipelines, cross-validation tools and automatic hyper-parameter +searches. 
This guide should give you an overview of some of the main +features of the library, but there is much more to ``scikit-learn``! + +Please refer to our :ref:`user_guide` for details on all the tools that we +provide. You can also find an exhaustive list of the public API in the +:ref:`api_ref`. + +You can also look at our numerous :ref:`examples ` that +illustrate the use of ``scikit-learn`` in many different contexts. + +The :ref:`tutorials ` also contain additional learning +resources. diff --git a/doc/glossary.rst b/doc/glossary.rst index 99f512cc49acc..e259fa69745bc 100644 --- a/doc/glossary.rst +++ b/doc/glossary.rst @@ -697,6 +697,7 @@ General Concepts to :term:`unlabeled` samples in semi-supervised classification. sparse matrix + sparse graph A representation of two-dimensional numeric data that is more memory efficient the corresponding dense numpy array where almost all elements are zero. We use the :mod:`scipy.sparse` framework, which provides @@ -1160,7 +1161,7 @@ Methods TODO: `This gist `_ - higlights the use of the different formats for multilabel. + highlights the use of the different formats for multilabel. multioutput classification A list of 2d arrays, corresponding to each multiclass decision function. @@ -1507,45 +1508,29 @@ functions or non-estimator constructors. early. ``n_jobs`` - This is used to specify how many concurrent processes/threads should be - used for parallelized routines. Scikit-learn uses one processor for - its processing by default, although it also makes use of NumPy, which - may be configured to use a threaded numerical processor library (like - MKL; see :ref:`FAQ `). - - ``n_jobs`` is an int, specifying the maximum number of concurrently - running jobs. If set to -1, all CPUs are used. If 1 is given, no - joblib level parallelism is used at all, which is useful for - debugging. Even with ``n_jobs = 1``, parallelism may occur due to - numerical processing libraries (see :ref:`FAQ `). 
- For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for - ``n_jobs = -2``, all CPUs but one are used. - - ``n_jobs=None`` means *unset*; it will generally be interpreted as - ``n_jobs=1``, unless the current :class:`joblib.Parallel` backend - context specifies otherwise. - - The use of ``n_jobs``-based parallelism in estimators varies: - - * Most often parallelism happens in :term:`fitting `, but - sometimes parallelism happens in prediction (e.g. in random forests). - * Some parallelism uses a multi-threading backend by default, some - a multi-processing backend. It is possible to override the default - backend by using :func:`sklearn.utils.parallel_backend`. - * Whether parallel processing is helpful at improving runtime depends - on many factors, and it's usually a good idea to experiment rather - than assuming that increasing the number of jobs is always a good - thing. *It can be highly detrimental to performance to run multiple - copies of some estimators or functions in parallel.* - - Nested uses of ``n_jobs``-based parallelism with the same backend will - result in an exception. - So ``GridSearchCV(OneVsRestClassifier(SVC(), n_jobs=2), n_jobs=2)`` - won't work. - - When ``n_jobs`` is not 1, the estimator being parallelized must be - picklable. This means, for instance, that lambdas cannot be used - as estimator parameters. + This parameter is used to specify how many concurrent processes or + threads should be used for routines that are parallelized with + :term:`joblib`. + + ``n_jobs`` is an integer, specifying the maximum number of concurrently + running workers. If 1 is given, no joblib parallelism is used at all, + which is useful for debugging. If set to -1, all CPUs are used. For + ``n_jobs`` below -1, (n_cpus + 1 + n_jobs) are used. For example with + ``n_jobs=-2``, all CPUs but one are used. 
+ + ``n_jobs`` is ``None`` by default, which means *unset*; it will + generally be interpreted as ``n_jobs=1``, unless the current + :class:`joblib.Parallel` backend context specifies otherwise. + + For more details on the use of ``joblib`` and its interactions with + scikit-learn, please refer to our :ref:`parallelism notes + `. + + ``pos_label`` + Value with which positive labels must be encoded in binary + classification problems in which the positive class is not assumed. + This value is typically required to compute asymmetric evaluation + metrics such as precision and recall. ``random_state`` Whenever randomization is part of a Scikit-learn algorithm, a diff --git a/doc/governance.rst b/doc/governance.rst index b8f3bda4328ea..82d69cc046345 100644 --- a/doc/governance.rst +++ b/doc/governance.rst @@ -30,7 +30,7 @@ Core developers --------------- Core developers are community members who have shown that they are dedicated to the continued development of the project through ongoing engagement with the -community. They have shown they can be trusted to maintain Scikit-learn with +community. They have shown they can be trusted to maintain scikit-learn with care. 
Being a core developer allows contributors to more easily carry on with their project related activities by giving them direct access to the project’s repository and is represented as being an organization member on the diff --git a/doc/images/anaconda-small.png b/doc/images/anaconda-small.png new file mode 100644 index 0000000000000000000000000000000000000000..ccb8bb8b707deca78f49e2423dbd380b48ba4052 GIT binary patch literal 11313 zcmV-1EY8!3P) zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3;eawNHOMgOr1Edlee9E@h#ftIgx5y9fntS5Vt zs=AA0W@H2acfW%JH2c5)^O*nOPl>^rm`cqpXUm`1V)LDEs(t=EU!9Hj-}@(iyyw2Z zZhU;-@VpfG8lE5Xey{Jmo<2WN-s|1?{<vjI+6|Yx3--r83 z2*0eK`LlGtEa1nw{#-vh-z&oBay&Qsm{$I=PVeh3{l0#8mA;P1`S%k4`G)@N<%a$G z=gaz?v+6nfnO%?Buw40_)W;O>cU;I#Vz?j6{4M;S_`ck~jlb1)JBiw2s>2Y6>6-Jb z>~Ya8*WGdZew}VI#OSv#eD!|%a9^vT_~z#gO4MI>d@;mlg-T2&kqf-+|BS`m=Qcf$ zhgF`u99MeB%@oCNf6dPq{-RK;pxpP{~`m7b>^2#rUp~&f9t|B1ryla~7`a1J` zu|NMPuz?7syUdjdb~~P1jO0FZE8I8-n+%sm`_zjo>%9R>#IqBFae)Zf>_T$c+2Vb1 zj&-cWr!#UNq8~)SC6oMOls<-#2~PETHlHc(Uh9*8UItzWL<%_+LPG%&tQb?w)LSKm zx=A6$lu}M5)znhYA;+9@&Ls<{dI=?#RB|b$mR5QVHP%#fEw$EGd-E*-W68AKN~^84 z-g#(`oqKhz@4PVl2qTU(@+hN@Hu@xeW}Io}S!SJW_T^VtK!90!l~q?;eY;I6?YPs< zyX?B#?uT4E;lz_pKIPQYPX80NH>%&G_A_$-I%@8Xn!kb4b>%Z^yj;r1Eu7#aDP}~> zM@PhqB0xfW#mrV0qgUhNl`k0Bed+)*6)5oXGy^rpq7e4QykKKD%+ro3pA9wjN6Z7wm zq|_D04yCL?Hn#GtX&ELL2@qnnMwqrzwF4vE@H z)fGP=vEre}x=YxD^;$Dif`M_=QsQpj4F<_8!Bbgr_K12haLt~U^y)kIRjK3Di75r{ zLZy@vpb!f4>|yYfKH@5=fe<}2#d>-vXZjA*mXs(;8!9oVH7}*`JfSDcy0}$bW9ehr z1Ivi6DF ztW>(!oQOlq=gy{rnhPl9hCnH)%>BWxQ_`ws_jO!4iTN#%iwK12PBGx3^IEsyBiWq> zppKw)?W9>ij8p+2x7QXzgRsCF@dW%pJU~ih3ucCCMrTZs8nGp$U4W%}Y)X?K>}U&d z)BxF02~_jh9!ZKGn~gbyN_2??8RM(jel(ylmyQFN#Chxl#FBR_D@~{WC~8vRmWh9g zWh4^uh@%8TMD60ry#Jm}j%=0(s3#yyD15o}=ao?j^b{4(Fx=ih6~iM+h5e#^?_lkC zK6LM)J6wGHxWo5xn$Lm$xeL)7$92%e<=TQiM#&Wjsh$B#k!Fsh?S#4PunzggV(UJ2 z!C}LZs$>Y|&cNF$BxPe1-gd>j9M}pq$sp?;-3A8>6(G0>7e+pD_x}RB=C2HL$WT1W zv($))jlcy+GD!;r693>gHA%*+rh}<-Je8QnHx_3-80`ySzX-bLDWU-|DJUxCdJw0^+a{2i@FJoCZk_u9Y1IukG@ 
zhYbvezWS3*5khg~z-l7r)4*K`kzPd2+s#$F3KdNQO5}!0V0C;pBuFHE8W2V4wkcdh zMwERD@AL^!LigyIHd_U}JJ+(Bz!B_a(L;9S&C(g@F1EHw6zYQX%X`ZXp-3vB1)c_B z1~OL_k)di9a)Nud9mRyXSc(^{G!Bkw+hd8qPO(V!W9K!98hj*=9UQU?{T&c^_}X5w zKog~oYO5>(h%m5O3_3v)r+Z`_*y_+w2aR~*1U9N%q$mgGoI%j5g+-gJ`c&ta_oDH#)89>TtxnY z@2en|?wj{5^Snm_Kkup`^p51mSJJ26E5UmluD&a_9G^Q=izqf=;ljs2A1(k`SASgl z^~&eM=Q7~E;EO2C2ukCr?KF%KL?HmkvBM?E98X+4w_pgAG!=;&1QbYROf;$xHs_|w z$JJc|ngo~zl@-g?rbQV6qf1QjC&5=*Xb2GrMCDL?h%yqJ-0|db^@T#i(dHMkQ5bRL z#fNf*1`!Q;PEwko;6BPhq)$1QnRg$zJw{q;v~LzcirqU#bVPLI&rVOaunIP1{yRz> z2vYhUI@G&SRLs=@5TD3M>cWHCOaNd@D5p&BZO8KYN+zF9KG#ID3?@qn&Wf#V__FFD zKSU88fjlV^$Se}M=!~P#`vUNKugy}Tc4qRdIkJSHu>dPk!wY4t>Of7|!8&K3>XLX= zTLx%rD7FT!$uq*5_)F|ZXoPPfBD{o~)W~J9w8zyTV#|SM6&mveZB4wB!Iz&zAj*?s z(_tSh`F~e60EMs>RrM?@^Ui-MPV-rv%Gq0=->dU`ZK^b%!qgbQ6sGw{WxjREZQshY z_DgMkFU_|wRhQ;lZGNfDKh^iAHh(MJ&r&phsm=eX+TNo4EMfD#hQF6Et3bMt05+$D zHiH1*HWXq-^F*gYY-u&HdnAKaSUwCmllYU}D7YCjlW%;nB5;ywmMq$9!*L^DHw%or zD9Q6*TU29M2F}D>_%s6K#o zLZ|w{U?)Tb57p)*KRn5ayurj`^BH#a=rh23b_6vM7e-8_Tt@<~i7Jf8OmZONYN;y{ z$}*DyHp^(r+>Hoc1>1w%d*}^|x~UEd#peUaP>UN?F>){z8YPUl!T3(@SLC2E9gJv$ zi&L5p2?@xx*bu+*3goAtP$}^N?ntd+65I7)|8y^8`FZ@=hl3#<)#w6v)h4+us-U(5 zYJ_E!Qz6X(duYem^*jO&{D~Uh2))c&zJQPU6@BE|QsEgaMot$(1~= zGBz@6WU`uEOvV5%%aJ4AwIN zLfLGm?jjDP;3`&t0I4$^9c%ASg2{F)AUcdTQzrNj>2L$tO2V^P9%@K!_kHZu{WkOjE{7sTnQvL`jD5kP!Xv^!#X(cgtYUKt0H_tsZ97yZuX zwfEAFOh-bfSGLqKm0xuj3gdQ|QXdwt12CeaCm&Q4k=`xwFBdGoDzpk=L;I4b;2NPV zPiQxZlF1|Pop_QEOMDY3=aR7R*Ha+lJ;iJYNK zXUWh<(~`;oyHJ&kh>xv~6pJm1#~lWf0&qheBjil-rsS)-PVgf}X7?Lw!<<+=D6A&! zs<2o{nx}ZLRr+!w!zLto!HJ-9l^B=&{m?XaI?NSgwv@e78_*pnJ>TXOq zM}ATxy$Dwkf5BsGyP2gDxbSp~D`UoYB3*c(8g_U2@xuY0Kssr6sRuBmG&mAXt;c)n zIt&9?2K;0ZN?6I$A~m*f<=jtNCMR`B1xihEv8IcW*3+xiGVp41-Z1S)4<0*!Sy;SbwTUcdJQF1kw`o zZSKONdMUw!<%R9F)kIYcirLFjW?};-Ua9YR95YU#K%TIkfMJjq>6rhar89xyjwjC8 zXSrbkTw4t41-NDLC%9)Qd&orBSQ!Gi!Hh}&!C(>IYC{o!C>;d>lVU)o;-&sIMNR3< zrqonl!u-B{s&+l=SZ|x73QuKkp7~z+DswPocb=VqMfERac?S;~iVcm$w`v*U=bcw? 
zxULFvs;lhf!UGp*Yl+XpIz%Nwri}_TINmTRkf!bewQsROLCui>h&LvxF6Ri1$)N*0 z4nTco6c=&da-9*uAu9@w5fukQBJDf_iB>I+jm(C7a0OFvTeQs|Ae4^$3Od8FuBRn4 zL~AR0!@6ckqtR~%wGTAe#0e)vQq5Ep)HP>N|FMmVEOL2GY@^dKAq4JN1kX$@1=11M zGSqvR>Qf+6Q4XxKOy0RLu7*9N8L;E%Sdxf20-R`z8im@^?fx1HCxzeyF5M(V~1WEI>H`78JO`9R3<$92w3L;|4YKhK9i^ zF9yy4paBvWgw$RfdEiaRR!2Ip=^DShw#W$tLlSRo5AE0*auIbnbbNx506Q=(D;9C! z>iyxFqtWpR>L2^9D$INuwRv}AeIbMmZYoo9V(`Gb5pgPku1!(c6fG``MDZrO8aDJ( z08$(}p3Zr^sKSr?1W6Vol{ebAwvVfv)`U+802&ZdjYy*cQ^VhIYc1Wqjh{U7pxEj=vo6$Lnx+Db> z`KL2kAT?na*gK4$4}8&VVzQ|jGEN8f5GT}QN9RpB(@x5PvXDpw6cPXg4}gafvI}0t z*u%9IUx@*0+6^(GW$X=g9B(^#X8Y{2T|OcvR4Wl<_DefUTIvqG;C7E1*5EN7TkvfV z`p|>6UP-I$wD+A@&j@-pwNj~gaD|*iHZvKW_Yk3LX|$KFeYqkK)>%Jka*33_FQo;EtO$VvMFROK~2E|e8QhgvJth`iw zQ!JKsPC(MzkUH9uvXCU+7+ZrlW7WY?uWbNI2!7%* zI>%_9MnID89KIx{HW!XHNd0R-Uz6_5IlV(X$!VQU!eStt$ZYUkjT78LN8GWPj;SIU z1X_f4D(VOo!jFh-l5{)jm=ASvr zDEBK|k(K`kGC68$0~5gVmcnL;6YjZ(Jf&`6$|~(vSpsj>R}vDMBP%!X;YF?2derbt zNb=#U=$G(RuwR`x!J;}$1a~SOq7De{iI1QrvFqwAvxq`J@S;jTjqI@V3uIXdp@trD zIueEOOu1zilI&Ozf7Icr8KI#O%msX&iw%$PGj&v4ReL}U{{UHCZPW)S5_D1kbq9VX zTs#zqAP&^#wksIXm(Dg|vl#B{e%J;Q2_RMR?A_t zU>^vIwGh3s!~_fmNID$_Xj(E+CzldWyOA6^PRazKcvVD~lSSH?OYurV;ID7im`cHT z%i-;u_X*Ve+~^UQ+Lom=EgjKJ5Mi~YLvFX`rx{o!NElB%)YvVQw~P9Tlt79i^7707SaLuel#lZiS|hrp2zsWmK`f;5opvlc(~NVf%p#$r6>jI zgM&4bYmGuRklUa|4;5K^AQZy5&3)U}9^}TtKnu}JB!u@?YM_uu?~zadqDPGQ88S7P z#f-LQVh+aFxf^hs)h}AngCh!VAt4*C3?Z2$iFFPG=ZAw3UFrm%kywSL2Qq%@Sb!u$ zvFbv)4dJhnI)Q>>di9ptRBNA0GJw-4}J{sfz-l8fJCJ5RLr#J5lJ&< z##-y}5uP^63iri;@p8f$Ayn{2D_ssHpr%t&gRMJKWO|*C=uaa7OmGPI^5h7@c>Ou> zuQK&SrZD)-ik~XfJSs@U=T-zC&02fDZR}bh&rU5erd$={j6EZuiVrQFiaA3mWx4bXP}qN+H7c{C1*W&uG8=U z$U)X^9Ra3F8yW{&6j-O^n1GH*E3PgHUf0C;7qKFU4XjP4-ETm_GIPkLZgQX1#`{^u zg2p6`u`BW`yJ-Uv9g6eArLJgjFjCMXx{xPiMS_a|2!u`p0ZX9*GY!A3_79T!3@c|D z(Dc%`y`l#|TN09S@{rhduv4oCv{enVf@!kIq%ZtEGOW4CRQ)0id`DeG*jVM#onVC2tbnr|(opl;Y$cHRP>sq2YDV(NFT7*ArR^SLMtbiQ7wpet) zpCTUdK+l+X@4x}D`1Zg$cp#e9nrhdFHNdT$GIXrjGOE=B1%r79HcXNN_uNC>3t+3N 
zucNVAv7gTAI!tTKXsCsU(#CH9|0e~?P~}Xd*ip5k@nI4IoDDi#);@%?@}q^Pt4#@# zuZ^t7IQi~%e1av@;X5=SqI+#~4KHGN7NsZ^cY0{xI@B{*Iq6GmNl?M)tc5LTVznW8 zrZf_Apr~6H4Il`!cG*n*`T@#Pd;ior42(7IY69CkuCYO9aMKYvU<6t%mt9G5wV}EH z!JtiRV9P)21FuxH==&+Y?7>4Hwr_>%CYh~Qt&LfdT!6U zt5rSJ`7GEaP$_pD@!Wl48waw}8@2BMDiqnZn7k_!=6Ca~DEj2ta0SxAb7-dHj-hId zkC$Sn7OrC7-AI8Gl!<3?acJxQR<@efBctDCnz`X5?P3USBaXTb5>!M6u24H879V8+ z|8Bwc4b#E2gijr+IZ@G|Ow#Ff+{#cRov%T^GJFkrpSnwzlsuP-GPz`cwmmcIo26rle7o$1_ z_%VkR$DGLgH>?F|3Yox&Bv|2D@&JYs>Qp9>g*V6g9|D0EiIwyr1fU8k7-tC_P3cjc zLM>{4Etym}*)gYB8XB^)cHkr!!Nplu2UkH5`~h)sby9SZ67Ne2En<9dc^~J! zbGYw5K&Y3QW_64Mnr@q^L|n{dSH-SZgdk82qeo_zF(*k$_^z*e1o(az<5~Xq{#?Ck z&SF47B%Wo4X%lY{PjA`==Y8S`E6OVIIq{fD4H7?cUGeyhbHQbSXGY9)YMwYkEEd{W zX=7G2HR37asH*9dFJwGcId5^+%2n3vlfN*W(^r0TpPFT|f9A{GP3qpPcZL!f~MU#c@7HfUaGjQFEN{W5;Qn z0KsSAN^kiqbztU`^jb@c9szyZz{Pb-Q}%$%9bn+ekWJZ@f>c6254@kzH)VnTTOhRN z&8>Nk(+40;vr66o2Zz9DfwI?q-re0kw}0<6=l25uA#!;(q=Ie$000JJOGiWi{{a60 z|De66lK=n!32;bRa{vG>`~Uzg`~j}R?3e%m00(qQO+^Re3IYlqAt?g6@c;k|+et)0 zRA}DqT6uI;)w%zD-#+);n>i#RQo!>X^OO=(CIC}Ibz7J7KL2C^Wxu)k=tXP4v(lR6_CF1jLJK7qNW|D1N{xc`f z9L7%nzOO#`(kd!j1Ix0$m-8_*MC6*z*Q!-3#KgpacI?=JzJ2>%^GjV#GBY!=Z{I!~ zJ9g}<`pChEn~@w8+%Du1?-_5QGuoBwtEijjigrzeqF)mciD(k`2RG zO%q!_5fOk83C37vugdp@So#Iq zYfl)k{I1`ygNHRgBPkm1Y~?CTo}UrZVgYL9ZjF>hS&u~Yjvhz`XOZIPYfn_e8RdGn z?wt3nzrVd}@9|Yf-aNPMbd2`vVPeZ10KUt1iJnHVqbPCi;D{Ruka43nx@CN?wosHr zTJp4CKP9&I8=;P>;~yW%J(R!C27ogqCo2F{060!eAG_yePUj$107Bn-kYXL7=34;9 zwKxVTmkI;0M}SU5lnKDGeNw5f0i197GA$?{w|<|lwgy3Ddn1t)J>V|LOl)xg8KI6m zHh6|I)z9oYtDtm@=QGh)cjj^Z<8Wd6H|MPyYb1cW28`HLPq!F=hzTi$+2iaQYdU~a zV$1qX53k#Gt?T`MKep$6ez4vrHR!yn_p1r676FiG0Vx{6PSxF4;Y|{ej^e+*w4!!< z6-3LD0EZnK(kZN0i!m|nvp0&ZkCKCS^$udXd~k!hwumSe+BXtpa5+2p zQzbZZXyN`o|B9advv$_B9o!@ha<&KtUHrdFSQmXKMN7G8DUBK0?an|ZT&u+Y@1XWg zGIXMf7QrB5f+NhFsQr45@2uD3Y8$*vv?KwzE=OcWI#UBz_|Naq>Ce}^!6q*cj9A*QpV&NcL&c3c_YEsnkAHO7?ws3WmJ#gOAhzDYrh2V@YHI58 zoP)5jUPO)o2qLCFq-Y?3peFg0rC%!^xw@c5R`UQ*I>vK2b!qk>?bm&j696@MVCn$_ 
zJqZ6a{M*{nQ;Y9kJ-x*@%r*{R{bkMZ>Q6+=O8}~_Aqtk-b+)@?RMGFRH#PXtF{nXy zJh#cmA^e*=XQa$@MFkB9AstLVh2W<8CfOl~Y+{=~+%Rg@Cj5Y6{7@U~YK_g0Z#WNN zHZjeCvwK5w zFp@t=k1rX$rkq@ies8$DzFR%Fc09(PY|bDiMf0MQ+vYYIbns1V@MlSHq-V7m8WTAF z2pRmD3NczYI$y0391}W{nHF??_%;=&C-%9%oVqk;T2Q-?qim`rZ18+XxVblVSx!O! zcLq2y8c29^;N7hg+a6L*wVD*YuR@Hnc7r=^aWBfcJwVljpbjBR`n;ClXc*{_(SYpc1vUrFHl5)01-fd_Sr94)&5lN*JA*jC4(2FEY2Pt*p6|v z33>=X6q{-i0I-wiA3c41r6SeT#K_>u5L6MCc~MiqT1T22sA3gi+;L`$Xi>&c|(vD9TbuUl2$ z{ST14PNr3gt(UXnZg=D*cLIR0t`Q?Z)Y+=>kA>i(5$t&SicyRm{44|~POLmKU)%ak zVjAdPklC-<9%n7Doee>qa;hQ`bP1f?-fLffQk*7QX4M?O_yGi`l&PMnrx&BC5j1sv zCO6F4|EwNsww5+OykVAD`gg=Msr!HR{N;DI-2i|Zjj_wPTZs&sB3d@@e|zUgqUAkO zG$v_odWVMK!bF7Rc^N%P(GU@-?J%^X{ijrXsfDq7|k!Nzj{99P5VlU8fPNzuKc^^xt*lpg7Dcb6<; z>H~op`}7@iA<_$gX@nY2&iv!dOJV-r-b2x0Y0t?3>WJttDQZ{WnA$^%ewOrpde?K? zPWuBlDjBo(Z|;Q|LzPo`0E7TQ(!7jYNYQ-|>}G=_0YqM^sSsnhjHk;U@cd`$(wteO z=-JTj;U4#*tXIU+7nDnl0#UmQ2P(&iUr!^a@|fY~UsDRRU(mJofuxW@sbtVJ2u@qo z{puGsSDjl)+>Y^;Ok4vsOyonNS6V-f_fi3CI|(J0Mbv&p?M zb0nK;#ig*So?>FFx}rLJ^F!-TTuN#PsSC5;C8qHSv(tX{&4%Ms+YIR7Bt=gsr&WNR!`SBCB7(JR*8%`0W=#0B z$MW9UmiAQ|wkm)n?{9u+{fR(5zdiNFi%|4&P)Lyfm0>gOF&QJDik%lf({!0>hB6#A zwY8t`9`pIGLeFBrGZlY$=MROa)2dbmwGGY+HA7mT_f{P|weH0F_a1!5s`X4kUS1x& zHcf}7eCJ`M1NN$Y>DcspA|oSW0YF)55Bq3XK!m+evE@wpxOOAGzxF(XX(#hev5)RS n9o98_@A5%~Lo%7D0s{O4_yF;NymWiy00000NkvXXu0mjfw|`wS literal 0 HcmV?d00001 diff --git a/doc/images/anaconda.png b/doc/images/anaconda.png new file mode 100644 index 0000000000000000000000000000000000000000..b384961b79cd7b3787dbe05a9ae610ce5ad5eb9d GIT binary patch literal 39373 zcmdRV_dk{YANO^RgJU~JHpg)!$(|?c;FyU*B(u!2$;uv~tYh!Jk|Jaar8wD=y^@u! zgzWqB`QDHFf4F}*oX2@^uJ`r6-mmc-uNQZ8)Tzi>$pHYM($v7|0RW^Od_NB-0e{-! 
zXS2ayWUd;<9snSCdijD>c5bEvKt_-zPRYP0do#z|-}K9V!(c|!w0GUQSq<-@d39{; zZHft<8WcKWgp6DlJ$7pI5Qst2nc{3 zJc#7P;jl{pU^(yN9{B9+CV-R~Bt3A#VX;qxX|xr0cqqO0b^hns9`?@85miBF8$*+mWQVifT-J7WM5s@F8R82c4(EtE zMgHz7Xj{f8`T5CadkkPioTelix9X>n!ooC(U(tP8(lkguGY z=ZEc)CAZQ-lj&quKLj@TX=l>4Q?W^POw~Drcu?#0Jd}OhKlG591kRmbPykW*_k-wu zgEI#JSlikvqv=BKRTx67v}wGakvR8@a4nb34X&RPN^PD z2SWnauih!8L)l6`*jlNyzc=3Xt8(C8PS&(h7T0I#kYM}xTCILcGQfJ>cW1JebZTnK za-{Tj2Zj^}8^F9VZK9*S8$JI}`E^yUoUfR*jm^iE{;Y@HaU7P$SA0I0`>f`8cS!Un zDD*$%_7r|#Vgku`q2KMM6pN|e&atPaky%BC1Gvq=w(>ZU{cQc5A`*Z_46xL?rdb@d ze;|Ye5TaeoiG{gEIKL2K&6RBn$Mg2jJ(WSa&CV-CYNdXmziJZgEqn6GxMoWuMd?F` zLb<3YABqZeJ(7wRgv90FATAtv6FuHN?_e^0>OOtaFYM|3_@rY~8;q(5N~*Mv&uPgK z_trB?b6o}xO`_))lwNn*#V8G~IvE=t%FgZ0Q`_ksqCDRQn0Y*iWa2N$J>t0f*A%WR zZ-F`<5>M6RTo6};Ep=!jg5Mj>n{ua(TjUx~R9Aa?yRf;o7vN^b*sdT-6K=6 z8A{IR5*091UwSmT(y-3uU@`wER|5_K@{@n|nB$VmsS8$5-hSEHRpuW()Anl{SO3{E z>wyvGxcObl!#WI37zU7NO**)>z23?MnbGZ&vEjp@M;-%rkD>tpzYtj()Yosn@|I$o z?PTqR1ppAgr1g#8p{3QNBP4Wh6y`XzogMf}4sN!6Nm&rf4&u;o2@md|P2f=>!mvg~ zp*3GL;%y`O-@JOB>$vmx=SXk6R16z8TNhI0ML1nH88`Sa(yTzZ8Oir6FyP0v<$M7g7rUWxTz4qF#^t|0Kh!k*v z--36FFpF4P%1+r!s7pLsDZ~7HXJ`LeBT)n@B(+lbn@Q+jXSG0SfcIQ)#z@UzlIFra zGBVkVnl(d2^`KPeo6)s9Gr+t~&^InTK@o|0qXOTteFe9rb% z0RW`x%P5%p(|czQCqZKYJF_Jn`7ixs_%um{ad1iuYaBtVR*-+DZ|kk0B-2a%ql)Nr z_fICvO!2aV@TL$L8DPnkS@J|e9r@r?LYRn^6ZpRpUcb6 zFRrFpo44@OPlYgbucvKeyAS$3fT69SP@SP?YrZLIM=zn`$t7sZeZ4vMaxq*3D$+c>o5ps!31eV+c-lG6&Voeoqb#pc#0g-W=Un2AMHgN_ zrfKQiUbAiC|B70MhkTS{XSJw0qauTLV~7EWX2R~uiLxqRO5n(P4UpmFhl^lmz>hk= zJhh6h$_rZkCjDuorW=ob!5*7iDOm2=VSmqPD|6?KjDk6n-!S6=bC?*>DS?q5ucV=! 
zsrO*3^e4os{^b1+EA`o?r@jkhCASg^(5m5-zKYw@to;uouPF8}&_JPx=T~jxGCZe4 zgLGOR&-Lb-&O>W?xiw_OC1BLBm%cL5<6VZM3Jt4dKyduPvDQFkV_(kJTXX^1GXAhz zO5VLnI5l`TOx0v(j#&YoxJE( zUm_%gpoz`U0VN-c%4X$u_Kg*d|5|n}y`j@XbCj{pa@of*KeZZXpF@fu^Q+UeWdhE3 zQY-sR4E;s!+bPN}tgl^k2Sm6Jl(2OpNa(5D_RkyUnqQrNS-0@W`xv1GwS-T>QLP-oEm9Cr9&p!9}F(5@?m;uyP zG%BLY-}dKD4bxK4lMR>gfbzw#dM7a15&yx`KXH(VgkDAl?CXk;NeyOhtnRIqU1zVP zFa&69YP)FaB5UKd*4fGIbJnQv-9z@r=M;lkLEIj z{P|JS)NWeyZ2Zim-F*K?z2Db^`(LV4`q)z-e4JLOst6lnY6sQI^Hz95uay_8N5TG;+4I_V-&UV10|ws&igO{>tA z+g~K(I(#2XsJt}a8)1)ow;gkB#%&3!jMPJG-qg89KIBGREiY!=dUcAmD%J}E$T$RG z6X6);Bh|$!TW6B;`|Zh?{USVinYv|i;H2KP_9WzOJW730rAcusJHLK+R_tPIhWB|x zdcT|KPdD%SZoUUwnh6-56LF8g>ZP2`Cf_8#z5Ao~)THnKJ*Jk$7oku}i7?H?VvQ0p ztDA+te5B|C3@8=f^Z0$aRq)aB@XRFeTWjk%S6_eOwzjl-XD;G^L5dD-`cT@%=j}mR z%l5oGQ@+*BEd@~knBen|rf9dT<0WQ&(_lT$11bDSS66je{Ez;gnQ~=q?M%v(nR&FY~vgy490*J;T6gOr}G0(zi=3vT%x2&Yi`-Jw&R6 zZ90656UkKJ6b#H%p@-p4(^W{IfRFo!?iFDH1#vs`xhBH!om4%;vU_>a z-8$(RpqZMBHB>)%Rm^_J72|G=gALD`QOTaSt{M4p&birb7k<*eHOP4qJM?+CnlQvM zJG(ztDmxbV0u2!}_+HdQ|E;;wU|5(KwOBq?a*{ zclF<#mwa$}+ZK3acT=jt$wQgvfGlrI@cN%r1|iOqHL}ADC(RlP)`Xnho6DiK!b9dV z0}@55_h~WiwY>b`ADJsI`t1dMG}VS8Hrh5=`?)XEK|z?HKPAhj_gO1vq8AfCZw12v zzrMn#V6<hDgIvh|;CSY$@Bat8T>IVb-vN8PLVf&~K?5d3+U@-e(3T)PUa`S=v- z%R_Q-qfj`5*iuJF`ACEAdp+@vR;!mQedm)u?LrC$;mnpWvl)2f5WF#6;MY$uoV@Lt z!PI8M3`zSq69WOJwk6OSvHxze6^2Ai+P-Y%eLk++ZgVc25v>VWJ~)?%ygs#Aj*)mL zYaBt6G~%3nG7?I!B|PPI7-0bckVL;80-oH3LBW0!LzY8rrANN2^fiZe-SuE5efIIQ zL+SkVsaNM2mGI7)$=>vS)yr43mXYh8b`D2(uBp)IIRsX|Yq4LokPfKz@GwbXtWG2h z!vLhFvk`#<9@k4S&@}!bK%M$*FCb-3&ZUHJ(C8t|lpAZ=qnEEuhSq>Wxu+hIk$Uma zWCuX+|BWq|9@v_4-SC@l*kGc5rso`MOSn9-XboR29k4BRVPrQtb*dF4>$6;&+QEj0 z3}J0O(_`OOWEz59acFWnSfjQ>PiqbIl*h7XbXVC^E+%nSjw)x-enA#s8VLJ6Wnboz z-DYbI1MZw?kJbj{XpCfeavkURVV>=s9f2ia39K;BS(0kw_S1-&n#Ij9T2MkbK!;8! 
z$f@#Rt|k`Y|C7czulM^%^LER5{j&G$q^m@x3LK35yShCOtv7k;g)Y&W-Pcs>Z%K>4 z*<(2H4E)>TKC1)u4U7BLR{%?FarUNw*d5Lu%G-UWH|Xg!66g|d7!66j1&{Ae<>}8Q zi_Npa0T-ta4xKwG<=OWL3IloEw|FPE&nWnl1OZySwg$`W$S5uN8v18D>p<9D-DY<& zJeX;7voE!VQA+@l+OG%w^A>zdrr6Bx_^3J-Ih@z^6ZG-#DOD zgt-r+YE{4;-218qRiqYe6=!)6!o_#W!dPGA?7o4~gpg+C(-dm&XJsybiqrf`>sy zS`Fm8v#U9Gpo!Kf6%mspSsqzgMZooQmqO_e1)`djr5wwDRKI3s&c8SBw0V3#nSb!R zJk^_R{+^yF-}v(G&Fl&8D_%DxpRSW!<#jM0e)p(9R$#!wZ}H`GuxJc6iO;fQ3dM=o z^bpDsrdLZ`?flAaAB}CcWMn}QYXH2TJ0$n$u^MkwJ<~$CLurt#go1zE=)=HQ>kv(- zOPgTehdv2y7~m>O0#H)HRz5lNBUQicvJIKDnOW8I(gPhKSX8aNjQCYM9bDK##I#Y~ zx7|N;OvxqL687mmBsXhp0`;mR?5o-#koT2Op0x+4Y^*)!OU``&21@&%C7-Lkuabda zn1pqh>5_evkiNCfV4KfOT4`I;RwY%tu6#?u%~CaSX8YeGF*mdS$s~Ad4}Su|qf710 z+sbj21fh3YiD(2fn*Y0%IJCt80a2DkroT9z5L&V2__~5q2hZ3G5-+Gy41+3V`+~Fi z5W_gDR=53qmt5MaXiua{9OX!ac+54gHw7@)?TnK9TORi=oSs!|r~FR4 z-Nd``@6;s>Yl~zGRynLJ`q((>th{Dy6*N{;xDY!wdV515b6Ks3#p!a9LfMr3mG7n{ z&wKTQ1n9pJli#rX@raYU?+Xyj?D0EyvQAHjuUMPd!NvT@c*(x&Ky50a2(5(u@OA{V z(zLBJHF*o{)wa{GG}Ts0u5+tim!wZda7&#%*}BcwrRyuUkr>)O>Mf=7=qh;+uuV=j-K?qg2zdY6A+mNjGFh;`tU}f`?gkbsE1o*+RQ7sN4U9Mqek-O|zBvYWDh8m{2eQ-`hz0u_va4F)!;hXnK2Y z*Td=jdyMU6E}@Dm*Iwb!f1m&Fz39BX#eB0|&4<0^xY)N<-j2>DNv~q14%Z)Ra!tnz zn&ZlYQi|aU$$#Dm$A4C7wT@3RR?DaGt>H#goG;c7hW8CkeHzgUylSPKkn_oDXdIHj zBQ2|aP2O*_UvNi4Zn(MmakNy>6Ey)4G89F8&&-OtU+rBL5a6|1da%WHx8C6FLoy%V z{)(Y4E|?O7f@;>NbUR5We4mept%w0@m}y_u&SD!XC{=dEO177M+xf_=N9yCuOdXhRXqwE8>_il!p}lJ#_%Fe(*(jB=sY>7eNE0x{+-Bc@Y3=CBtlzQt(PbZnl(;IN57H2<08wj8!(V8N{;+YoZ7HwzoT-mbQFUxhKg+{1?Npl3ZRgfYX z$zn26R_E!uvF%U0KV+S+P>N9tlA`(_#$G0V0D)i@L1JYoqUCE<;(iEIU9r>dc|e<+ zTX)0i(F)3?Ac%Z#7thz->pjUF{#p|bLU5#io8ER2U=*?Z!@&N_63 zk*W&i@rgN3f&mV=-jjcS7KHoT=c|xT~8(f3Fb}J<6{uNqR-PE#A zb4rIncdz2Kd6&#nxA={UnKhxhscAc>(<=_I&p(`D%tC-0wj$%Xk=}Ak0^~7*7{ATI zH`Myn>r?Ddo?Du_mEKR*Can}dEDC3xzy5zq=m24^P5*#dc0=)Q=QJPbcIQiUq_+5YYTw#ZGa6rg90ne}h^}f?n6H9H^>L{`W>K6Haf%Tm^&x_O+ zTF)&E;73J^I7*crKbx_*lAVeXU-dve3b^)s=3^-3-7elGLBVFP!?5Xdp+)7YGcdxZ 
zruD_$Wg?`&inn^faklx_WTic`;mtqRPL~)kUDw>a&3*f!k?<-Zp(wEi9pq=8-^^%hs=oL0RSZ{J3e(1Nl2Z4bK+hon){T_$YIR`iBgV`yUVpP1untD z*4Y_>2>TfG9Eap(V8}B(x4~%>|8S7ti31Oz21xzw3vZr@1F8u{Oe8D5QNYu7k3O zu`toHr@3i?9V!cu|17ruyL|-9EL@W3WfDH4p~_`)_}F|Th_W4W!mS4bXkQW1H%{lw zG$bbz?OiOUPb1~yHyfuh1Bmv4w76vYp30MkQ;A6ii(b*V>_%6A`Pm-W%`7#jp1)FzD$j^a-4cz|GwsPV2d!oW!7i6 zEA7rS`rbZT_x8+qdl63L`a&It)(@$Y5t&U|q*Cr#BL+0+cv-F7X1`EVCZZuBqu`D> zZmgGf#|D2S5u~OI;V($$>y{PubDk3grhco6|M^AoiUpA#?G~gd!m}?chO^ zw{fQS@@x*oT54SvgI0d)NFOo?j`w_BJiAUxU1+uyIM^v0NEt;3$WYE@8$%FaU+X%D z+|n(=Tmu&x!8FZ7@6Q`@`glI)gkrmH@pX#Q{L(BrI5mp0*Rc=+!7Z2ehdPGSDUslVkcnwQ_t zXzi6nVQ8)WH8;!K_uP$_Uz}MzLYs`(@1l02tnW*R!PG&yMxr%m7x;gMhd<~1E}PwA0-0Cq4*z>)X8UO*phVtDEN_(pOI48y6RnlXS@SxG{TSuSPlhPok4zvgjwe86IdH!3zk_tWZfe=v65HG2^|bZ#2O^Uoj)?&!DAKyWUv};YH+U++69Jz3-H>`8 zH=4*8%?!Q=IakBGA_oP)D7+8c8Q7%JxHfTgD5$6mgB&!|2a8XOLxrf}{jwKrM-&y6 z7J-hvJzb+fM-JfK|4|trB$OA2N2EDy+n7^_G44lF-Z7>VD_docWGC$}em%?%2P=4I zgw+b_BK^(_zIjD6GFcb9wI)VsOs%3Sj-1wa&t!k6u5J7;AI}e{O*tCKmmTSs9c!3r zKi=pq+iG^{oj1^Pw-@2`Uw7?Jz-9eqj+!QUR?YeE* zJ58mF?EL8A55Wfy<)j0D9nHj*!M>1^t^B(ch-ldp$+KB0wRuzirje|~HTj$x0Ra2m zBmijPnmK40m2UMke5zYah5566(?2Sk;RX1-zj$olKm8nDQ97SwA0sW(%~BvB=Z~}* zY?~gF5^lKJce5<~J2G7XqCp1}5MP<-9;Ue^E_QiI8n17UAA@v7(rQ-mjg;4)z(@Z= zhvMZRyNABLvMm`uUHN2I6h%259iBT8^@1J4uh;j*V3~=}LjExA*4Q$8N|4vy%Lt>y z{ugj0^}zA>ggE)oAHSRL`@OkY^GQd}b(i{j6d}2!!-y`uCm>=A@_IE8tOn6}w)Zu@ zh_}?r6IW7tvuLif;Nyx3kFu99rK~(60AKjM9f7G7v8-bG3c`D`y}*FJ3dS^@;}22M z`C^PXFV8lJr5^+JVI4;$vVjx{s;G08LQ*mh(QpmD5ZO&3ka6+Iq1hK7p_FY-Y_avh zT4Ip_sv+(?masQM(aqy4@MQfrSDNd+-*}JxB^)(u-{bCdc_+%G+gP|6>ZV8(j-O0r zAPlL#%=gfr_A3vF{k+VPD)VT(x_oO^3M z0}?zTQhrY&7+peKVv@eU)JXinIq<@ACzH+SMn$0GY7)rN*HnljRfuRoKEs*-3L=|R z=c_brAfxOwco+3$hLUc-ItRpJd>w6%>NCUmB=J$>-Iw! z>t5zbj{k0KO^m9&BAp;4)Ef}NUHGW;c&AFP6ol%dgkhv6bkox3Dfv(#PcR*Df0Sr? 
z`|mc`l!XIJghF+s|Mt-@F;jx zad>E~;sejR$<&!Bl0B5luc={U;{3&vh|pwpXMdNcOztEi8Vl4~bUqaRz9KSy-|Y|B zF7GB30FukaV(=7&_slL#KwPe$)c@O1{h#&vMaZ8!+cOqidj?-8tz0r)NB?qht0R#l zxNvsgLeaUlqwx)e&JV3rv;e~v=j1Kf7`F&|TfZ7xJnC%ogBDb`QkV@}HS$D_1iZ|{ zFYuVvHSEU?}q<~(Jxuidg;tQWBkQrkFa(dE4;Zq zj50RWn{~{&#cnZ;svVLuJ!O(TyP@+m>Q3n!Z|ChFeU*sdVl77TvBBHooBzcz?4@5L zqAFgoVP)FI;zu}tF?1Z{!$IncOLFx>FtC%+OP5ZRyD_q*{{WpF4Q`94NGH3{R~d%ydy+RXhGeZNg5PHt?1sN5^pm85*)(1bR+ls6B@$*>L&w=U%gd zhGiLa3ZsJl-B-92x4N2i3E}XB8ar@T`Ea-2i-Qu?GA;zzQ z^SZBcs1~J7#+x1)pXIgJ8syskQNzcVmsk3w{;#=Vneev0n0arF%sV;g$qc4EJHOLyPW|SL6)d}H-jY&M?<5LyF+QTz z|E>tcYmUX)rpol<9(JccUa4D~Mu2zqlm#A6BK{`M`quauwQNPSj-Awd8DD7?`cD}I zH@w{>!^VXyYyVLGfB)S>r&{kVvD4~35^wr2?YX1H!c<^GTCpiN)8_F>YE|9a$8+J7 z7({}OgVbNzyLZ>C!fI83r8k{Nx50j)Gj3Z#qLT`NnD)FF^OHKY-!Hefq(A%tb`gut zf6W>UhEtAH>+JQjF!vsjr2wl!N+o4l&iCeqXa4ox*V5do?WSkh+Q?wRQ83sKYQNLP zs&uaesX%Hfu=1UM>+nf@Z3p=oHo*@p!e2!A@BqJj^Zr-)&EfDUd=!KGl9PjpSJm-x z)~&46@%hC8cQ$~F#*j}HMDXu6xJ=p^CnqO=$uuE>j)_}WP0d^#e+v>2mafgVx5;5J z=vt3!SVM=Rq4AR;+s81V*nZG*^Hvdt_Hyy0>Zx+RZKfdC4<~29Z=F4^=UYDJYCTDD zI`nrX+xsK5y_^ji0zd%|?0|1dH^|o?D#~6^0hEMuIS<=`&!Dg3JlNN#plT?O_&_T$ zwOt8S;<4?hTa9p!_57Aa)pcfr1&!%{6;;dwW^aC6|WH{PnqUcf^=edCxY0 zU1qx7CKZ|iL{gQhJSPkcjQ+9(qU`IlP*!>? 
zn7zO{Yp1UnRs)y1EC9e+#A>4G5l9tkSk14c8k^^N*FH?GOQbCNMxm44?Vi1?SzQ*c zk1v|uS6YnH0~Hgi*nELrmUK$3bRt$b*oRgC2X$nzwvEF%;kGi|KrrDT&O2w6Cyt#8 zL7#|b)IRTe!)JG9zhDvYB}9Z@&zgs$mKg$|GwJt1EhhJ=U1NeTGoL0o{6-=S6HPuq z3Ls0~+LK0$10mVZiB}WLNrbP^goL3%y1Q*k0WanIe~Vifds*6m;MRuf=O;>Wl$$J^ zq)@Tx+ZpnT7ZmEVb&mQU$x8SjleLN^&`Ke44!0i5guF%L)t#K=_^JNdpUiTOeSZaaIgOiweDX=z~7uadRj^)h|> zPnZYXFlTd$8^?`{NRw)K27-njv5JxRrYze~9?O3@Z}m`Jcy9GLIM8~O6R}*;cD_QE z6;yZd`4T#c(lDJr3er5U{;~?1gowhJ2Yn zIevV0Uz$iTnQXY`Qh2ZD;;qK|^7|?Vu`I-*^G!t+*iBlo8 zDoB7Mk?s=K%I`AEyxLoj*?M)&FTzryO$;UFO+eXIm|2Mlk1hz~qy}I})lfgo5|{p8 z;_hJ^*KIVUAWLaAwcZE&1ST?*&y!VuC#Ty7{ZNBs7H8 z>#c%O>9qLHK#A_{#%IT#Y8^k&nt0Un1JfY@I#yHjdv?^qBCspfqUu8rqx*BUUsKe+ zv|E1Wpq^VEB3eK!#yy(vYmH*I=C=>1T4My5-#047u3#@Qmtphx_LxMKj>J_j5hfMn z9c+GjR`vbcOJ`l<4tX{0jxI?Jjl0=DCWao{tQwxKziNp$v+xRB{t_{783)(0>=?2n zfxH_#%&h;s@yJZRVVsiIZ%iHJUjAmktm4;))+h-)EKdK};dsmW*VJ4!36`|MLG8+A z$bRv9v|t#S;RqV>EVh@i)+b)9flS}q8xP|_Crd(+QX=}$^_qOc&k8Z2j!lW%-t~6h zK=t!t?^R1&K>KEB&n?RK-rJz(ffo@X8pbj;<7)|q7x4NQeMuf@oXNgQpV*N>#(4u* z+?b|v5;0~65_u>Um8!$XYINY`5k*H^$KBar(?`@anfu1c=k3ZofbY@04GLj7%&!Rd zQS^GEVNG9A;CWyB`dcT_2# z@xpDNzQP`PgU8{TQK}wjoke-g2JIXa%w0AZAOXHU< zcK1o2>Y^gs>pmHstHtbZjJas9tt>EC#Jzd^tmWJZ;P{!(;`_8apC!<7r0k(*rtS3W z(#WQHz=pr#{Q3F`1PJB^Li*AZCHdueHyl}*xR^v?aqJi}zpL;V8^#>Nz<^~NeP?M{qGWZ0tmXT6%Pme8)}G|yW2QreP5)B6fJ}rs7^67X>m=_E_61EI z`SDcMw~p+_i1Fr!fX*lH)b1=v;T?T3?iM?@vyKe2{(!W-mK@m0UedY6@A73k$kVI? 
z33+TB$+g6iRm1F*hWGyj49iML_7oCOUjo>tq~9w{HkCW^0lNlu@K{?ksCKi6;wO`u zUI#g~lq}~&sORW{D!q&5@QYl>L7Yu_$d?T6{6v`9^vz9Hx5VTH>5g3%UMoe}-#Qj< z$Skzx2a-k_Yl221JccU8AkeL67In2b?`ysi7&vzh=hqc)cC*i%IY^L3ZNnva)Si%B zekeUHUGGc^XkeNz$zDP^o5D~XcLW@(EQx8q1%+$rlf>dHd_afcd7u1@&a<}^b5JCb zAIh`z2(%L%8m)YyOMD{Ip9A0${V-}Em>z3w-M$B4QHf#oL9r{y%dq#QrQIyHfy2I8 zTFIE7#3GvpZws*fo_$9dg=a$%P=cxg5gy#(;KnjyQ zYBp}CjO&%r>?(2nc}ha17Ys)H$_umhRL#d_;v>#NnRi1$!ir#3)>LS~(UyI5#Mi96 z-3s9|0Da$=cM4NLRAFsOv$v)icj%@Vix~VzH)pdXTI&$+>_mDmHWeaduYsPuD3}TH%@?f^CzMc0iXYC{=r|FFlvCEM%K)GXZoG7x}7oal)6cs zxPT5G0GP20URYEzD&-rupkq#QyiM29)|&9O^z(LoUH{;HvsC2>$iLmp0+i-&kj#xv zimfhs%{9L-gsSPhzKSNQ)aHeBr2jrU^0pJ~Y3X?)&N~sOs2#Lh#rr#UBR_7s#P za0bop3APfKkLtQo1i^HMmmc+dAqWU#=|e*GH2IU<5SsN@dBA?IMc=}$6;}Y-pn5kf zru=L#NNYWV8Ok{mT?qq3Ijz~Iyd{M6OpNX6pq&EV&tM_Z=a#k%acCrZC=AW-U7EGi z=sn^M8WuM2Ji77}v9|9Ez-zK90eAG<#aF$s*S04n4(9Fetw!9xXYE&m87{xU`^_x`sfMT#w;Vmt`C;K!w7iD4X7eLxJB?9}de~!F(X(c%w(> zeHSrn82i8w(SOKjdb~m5E_iAVA;8F|GM4`T&O!EdB7?o9?e_k&eNmTHLH=V2{t_JS zJ(_h-fe^}lut3UPg0b#pxR*ZXNv2OTEN-~+v{t!mhJ!m{D{UcyTz(UeAA48xw zQ8LdqV1;H;sSF5Bx{#MKpESn@{-tEDw>4Il>afh?Vew=X)I%Mbe z5M%8BHhmqbO3w1<(iz}=>bA%T_V#~+N1gds7JNf$EX`kuS=~MVh5}&tQYalm0$IZ5 z{rNvL$0qY9jy*G}^3pE0UCD1KaM}{wo$(}n^tAtmzW5iAv2ZV?YlgbF61TDI)=THX zeA&z^j)Lx8Or6`YFwr4cT){RdF3x;5V=V-#npiA5vUt74b;hV?pyD65Sl?YY{-%XT zFp1*dt;LQWs|SjKuD8d>4&({Lu5yl&nRlZv71eMCP3(^r&WQs{Z~-M43i8UDo7(#P zOJe+w^C%QW^p-$TC{66HT?Lh237g3{YH zWb-ThU3TbO0_Q+K>Q`$A)S6>=b{}t#u>d=3i6$8*l*i@*0xW;sKL!~r`F2n_yP$~e zO&5F=dfC9KOR7>1$@-e+m8>%=4e^i%d+@ znYnwR)jyv(|K`GOb?jiPCmD9|mmDz+kPHb%xB~9>6rv~;6iP&l!}XATywlf@l;HNz z8~$pOwj^}@Un}3BwAyFljIB~!gp8vk1zEB_#k(hq@kVA(>iKxZteXw6;M%^&YG{h&^^gqVQQ|gb76Ah#|BzEOyee9Fy`uFaK{85bvZ@= zOnq`PLUN+sU}m1&`!h(BOv+Rk-GtK)o=jyoAFdtLgU~LT^VBbl-DC1}+IZ#0?)=`n zYz2S0(C>>31(K}uuL1%Y5ClU$s;Ma`m-0Klxk&*oT|*~mc;rGLxh6LJZB5*jq| znFl>a3YfFK1r=JUZdqI!i#pR@s2O9R$j9wTB#=;&EK=f5SL?Pd`loR|WmSN>j|g^6 zroM47hhG!0HhOnQ5A-|0V5sHI&?uSW>St#v<@gA&K0WOOh3Rum3h?x)Mdz&yeUvgK 
zM~O+D`;5jz0C1q>C;^m2f`f~LlZx3_r|mt{LE6(vh-T3JH96(IGq3j>LpvsuJgTsl z8tGznnTHEIo$3l(Yrh^DJ;XB0@05uzdhBr=I=4N%yi>X+DbFX=PHMfqCBu}ICZEun zmlxXKvQD@mSG~k%>-P;8F=Mne=-KvFGo(EH3G}m8@WX-1P4r4vpqZH(rhfH$qqw~P z>W8e#Y0%N#zILT;7_nXjQpu4A(d8JiY9cu78u*pgqq+?jHY4{c-(~SYeg62X0zluF zMs0Z1)(Z(yX(BV@AHW?X1OT<&8PHE=F;ntf*Mk|XlpglV0Ir^1dsMOE?BTRL+n|*> zbtT>}dvAw-?Y&!G1>9FCeWOJ=;`OVE?bVF1al$gX+eN_{%HKR z|C9dMo$Ox%WcPuSC&!*nJFk@E9vmMK(OJ8-KYUd@Q+*z}Bls3nYigf$m0@Ov96&od zo4YMPrSAw^oiJ!E2<$a>0fj#?Zf}f!M}fN+#vj=doFve@wLUfVdnUqU{>-seXsqzV zP07g1?k1902W4klm!#UJ)6&94 zXoC;i6$d`-cLgnUErR}_-ZYs#p#nnyaA*iDd|@C;teLxIy#V3PVi;*$x7@HHTybHs z-)qkVy_!>@@6lMO-kI4+Vb`_#d0*+Nu8~=9Pedf>%?HB2-%zVEh0tmy5)k0#76q8~ zTFjqJvDI!|7@di^OGIfrSr}m-024C%7#=;d+j%H750;Kq2(~U^L2Q$iOx=;rwuL+& z(ZbZ29Zsg8-Z1|10idKL_FBj0@kSN?ALK9ae58DTD{A^S6 zo@BiAXB1md>$-`g?)LFsqpA&PG5-UH zi-WTtJD9=9$%vt$2oXx2PU7F|6ZHRZ?h8{{^G)xqRHmIxHh*hKK^)>NL0g%}QhpR1 zdCx6dvemGq@OOXn$X*(w@O6yGrZnj8%=t3&i%t}B>CBc{4x^4`R)V5XWH7>Pc9x~s zo%E>6O`kW%Y7N|Y-2P_5`Z>z~R3avb|=n=w4du73yn!)zOPB(uO;I)W;s|Gb@C#aT1xl&UCPSZaZUf z;lH0Zcf2Soi$FAfWO3z#1b^cc1h*E1T|BAD9_m^!&{^iZH8JrltV(7#Y^I1}u*bsi zbTa0}f>lQJQE~}3Q%sTvUG{fqH+n&#ng?ikSGge0>>>IdWq7$Cy-osIJM;TPsdl1j$eHi8!e6Rp*8Y5uTlV-oA{bi#BI%1mkr^0xKJP%VNfXf(+3TLhkc$8< zUplVZ_SLeR14tVqBO_&qrXG-=U0BPn1e5X6=f^0k<9qH~$(D(1uVgoYLv&HE>|N%E znpt(dcm3eb2wYo0(p7ttt@7H_a7-0y+z*tvCi5+Z*P~ z$$W%2Yk}l36ALRL!74s;OY5L*pi9P!f@d=2IyR$+PPAgu@wOo!+oDNPOUoFTv+mQs zU4Ot;6+0$S5qb6vA;(v&gsHt7t2ROkW}HlgU%6D!U~wL6T)3E!^_!b|Y!z|ep&fhR z*zY{?Os~p$jU`0n8f6R>W3#wCagV{I8iA4eo-IrxfdUA%96W5s7sPCC#@shEB=xAh zQZefov%U6&R`hP#tJu-r@Ti`p?rbpxi+hIw2Z^8$YyUZjor+kl9mEV7qU_9smmyj& zH9p_qePn1ISj&GJyJ88ZWYDuxPdpM=ajROF7NJ)& z>OsGs@HR(qx2XK8dU9CR)qO^`Gx*k1IxI(oO?C^}~ zP?D(SHZ;b*Wya-M(-rH3i?8S6H!HvR?m+&Yc;Unu&J>l`dFQ|Z(yD(LzBP$J7>qq3 zjwB=jHF_q{Npq&qw;D=#x5Ls1MX^@zO?$i0G6OMvzU-7I+oDhiW?Ti3hgGePBUq3k zSO9qkSF9GxK9xfXHp@`gf!mPHmMeP6$JdcaA|Qet0%80Cms32EZy?qE(dxqEqnp6J zG5jzvJHW~M;xIOWO9~wh-y=bKx2Rl1@US6W`1tu(veMnrc<1jYZGCyYr@~A=e 
z#JET0$(7qcU#G9Vzs$Vj?ax0wtGzVpdd@ztRq?vR(gstMu-E8choK=I9VBa=Jb^kY zf2{q)Y|pHR_dw80jC!;+_pj(4Uqhwyl&I;aFiH$RNOu}!D}V#5G>Paq_J>!KV>b&MPD+PKU>vPPc%L-mCYw@z#`lA#?ePEepckmN=f#7mJ>i(?Sp z81v)bZgSt6adc?PLere`1=@SU4EM5TX40I?L$vY+Fa7Q6&#!k&$E+O+o&OI>XBie{ z_jU21K|oRo>6DO8=~9Lk5JbAAQ#ytakY?x(5grghK)O-7LplW+8l_Q^PmW*wjJi|oFmM$>2X z@X%cLcrCN!*_TC+21CPDaNqajqxh7Ag@up$9AZJrB7NFULfoA3#lvHl$LWvh^+83|0M|0;Cuv-b+HdyW zdsfG**M?jF`Q;gZ2R+?C1w{?TqWv2;eJiy=qm*}1?91XSgZD=`XyJPj%9@snWw2Y& zvr9@LX?pn`z1W_4^3N7mjoq+A@?M{u9#TZALrBDk`v+!6n(H=)`B&jkMv{a0x+mx# zuxn~OxawQ|5B5$*W~{8MqFmAC|~Kg zxrmjhYZpf;Up}}g%u&KBMeuMoih=f?0Q2S`wju+xKn)?qjQ+@*PMFEZz15uIDe$A0 z+NcfPVp@JKfNQg5k!CM|@b)W^Ktf80)VVO0X*2nBZdC{IvI{$;vNWYD%v3DRgFh>i z4Co?i!2ZM}NU0MI3rDy8j?coPtcuP66CKM>BDay|5@Mr>=a8a{$AQQV+VeN4WeJ?@ z0DJG)1ZXddhvvQkWZExge1-ob3Ze<^fa-%7G>$g3SO| z1ljvyknO$%X=(@HacuZdE{91a0B1Q#BhN#O9uI$e{tS&EB;3N9LeFIR#tZd8TSyWk z4Q*WKX@Ju0P0dN4v7xKQz@Ni!EpPCY1Q;iFYxe8-J!ejH)csdHNY4cdNSW&DLpxzj z9+Fcr8qcy#z!zwEF7jPk=WINF(Qwr%Ac34wU{f*do%HwFQ%Vl_@Q>=hjqs@sa=EkBCPg1hVERiy@=;RMm=KvSz6u4mD(s zQ%fiP@%dGEUH+;L9?(PiJ8j=A(J&}x7Ba*qMoANUu~UdX~gS*%Pcr{8!J zI><{j+q5sifp~+7i%Wx+5*57$TBF;Qdt?v`?$}1Sm!VR=JxZw{{&AF(ftSMy~#&sYOP+qaW9wO z@25?kaJEMh#0-1sME{p4D+|yXFEOa7;{{a*cM2nYAn2he^$3t0j>5G{s%hqCEb7M9 z9Fg{(NqAh5gR1>>L==Zl0cYCTKPkY$7)4nXD?g?&bw7R8VR+t~;xZc4YfAdY=Kb4u z?+rq^wWFe>DnFFAJVv*GB{o0CAR)ng3IGy%rC9d>fwI+1Tcd<+==xY9RZ3AK8L2c6Ngo_dJgoFeIWF!!wcaecBt1<*s z3&c+u1X9V3(}Wh)@gVENRPzYUCi59V!9#<~ptY2WF`<`PR1js^H9n)qb<)ZP$`3TS z4IXDD`oh(efWHd?@ zXt$g|(T-Z~4&!01ReCB5@TJQ(noD zRc7AG0=|lE63e>e7sk&LD@;=L&<)wXEY#7RER=Ax8OH{6HQB6+o)y-px|bf7{fF#( zvaw=l;G;b+z)_4T)+i!QSzMv3RGJc;8jkMT8kfXK#*ap&Q!_f0l`BC8@e@?IP;oVS z@v4Zwo)*%%D`4nBL1AnJ2nQFS((sd*Pa~!^Tqs}seeef$V5{t+HrT4`={_$GJQYm> z9Iq47DWHj-AC`K0O4EX8u5J^f)n(X+Qnlc|J`)v-k0mMc(u_6Wo5bGepW<~bM`YSJ z72+Os9u_F+%C5k>$E?)`xxH+pe(=CXpA)#}GnOv#A=B(u^N^eu$53DcfIxC5kXqlJ zJZT6rqTYtGSjTm^FJK6|7`rC0&-T$w%pD;^2ne-VJRXw|=iIeeA8pYay2s6%bEfK& ziXTr$d5_Dj^(PI5r|6}#dQdzzlDhL&T^%ufE4lcuSTMrT$CAi%U{_%Nehaw&%6P>Q 
zF4#U9uy|qtQk4N7;6lg~pK6!5s4D{>$$>zCVt|ppzgR{*f&f?dytOqGgtA8;@uNv$ zK17p+z`kq_SL>rghzTp|l@u(D{~!?9984iyZnfA- zcr1PYK-WCcUn;WYfahuO##MLg`+jXs@4`usoHk3wY(Ix&)4l*QnQ=D*DfUvkSkhPe zXqcYIr{Z#$Uq-E-HO12T`JBNDOp{solUhBC;%Mj%m+hV7gz8FwDGx^x%I*hnIx7`z z?pGs5um5#75OnbO>IKN?!ZEO>a1@9*#IWuMiY;add zimWpC)@OxZDJEhH)9=kXft7NikRiNiEgW7&|9-Z5D3B~PRe*za?Jv6!V-5R+HAQU8 zf^P|_Uu!-{aAv@kz^fcyYMqYpYE_kE(4@~Xp!TpDbbDS}kM>u_>UpzDp`x>u(#Q&u``D}HhUqVhPJiBKs0-cu6L{yp8gDk0C) z^^%5kDr-(9G3rA(8ne}NK`(R6=Oi$5bbIF;MW?RivyOse`)^^NWa=St1evjT{W2xn z?k*v`Kk(Qo)6i5^EGmJe_Y76r zf&B(E9^1W^nIpiUz3Dg8#zAXxOagYOIjyc3h8fbX;e@HLR-z8&H=>Sw`TD6yKR zK;Q=d&_N0n}>>E6EDJLCVp(W{yvVkfkws zZr07|Q9MKPVNGTX@^S?uC4CrEGlZ{Lpt1ps*0xW9#5uKL1y`Na||9#;B zUOb~HaB-b9kTJR3UI3~6wE47@PU3Z~!E!!VVn_kFkm)mqQtbCGI3Q&sp8h5U>Y+&; zvg>Q`<5!H4Lbu~0g%}r9VrPaPbe^{U6!YDYm1vG%l}Y(2HYyWiIG~*KG^-?E;+^>x z;Vmhw*<^A56bmHD*`P|d$7@!>cCJ}m`-9j}28&R?jCPq@?FDAlFr*J?AT$&5^wNB6wKatR{gPl(nfkuSm`4Y9w*zw}MP+r$xOPv9YY774KDsuO=T1 zkw^4f^LY9$nGo|2oqKNLLQokxPB`=`^(?ycSgFC$2XP|zVMb3*y~lpsqdZUoP$_&}E*dG0dof$$k+#?1>&rCe`u$nu9 zhmoPPDCuH?Q)thQbe9b9Wm??-JK|ts;vfa#@|#dg)%LleJt9AKtvT8=M4%zXE)tbm%s6s+6`3I%d3$-1^y0?>S``yWdqw~ni%0A*qo=y>A%KIL(gaDg$o?Shwj$peq zRD82ha9MJKeAW1_W4A-NZTYVcqIZT+fKkUlLmE?%5~M~9)~Uo}KyhP|0}xL}TNZL$ z795yb^jVoC9I{pstJ@2wF~YVpZCWVjj&8(HEgQAb@*LMc-#ZLn#f0U0)XJ;9m#Y0KrRgxxMdrh5ca1%Qo@?gti!)B#VS$IpT2n4OXYQoii~Hq!7^h{w{ecXN27Tc=j@?D3^4 zkfYR;KQ5j7Y51#QYD5>u4vpWR`=npuxp$T(O{9vUocow`MVl4X`3SU6<6ybw^psWv z`}UCsTF-)(9*8a_qXhN{_eXCt;_XcW8MxOm>s)jZhjLI9dj!c{gE6J%C)}*xE9yT z7cM9$YDid>$1BZ^)h{Tl(zHB7-}Rp^3Y50d zJmw9n?JMeiGY!RM)eE$|x+k?oxThr$5L$C~1@@i-#!jI9K^&KcMm@{(7plRo z;|z22{m-n=L@swfl0H!U&PB4k{A;eh*^U+>xi)%yd6K~KylwlGe|0z`rKuCx$OZfN zRE-*IQuuCO%!Bf%$>r>9NE76~siQ9*gUYERD9ZEw)6{v%;9~3 z!GSiMk#lx+|D+Z(hEYf`Yx@K?7aW^zn{l$3t`NO}%|DROws zG5eu6QZwb1>Zx}Xuxe9SF7;W0j?vA+AeU(_5E#yb*$jXGJCqGzC#F8fV?>9Lk}|VA zghb3I#p>Vr_l5mMT(%xuIDvEBC)`(T&*gpu`M&^(v;LxNIiaHYa71bG?z9V$WX^SQ9Aw+uKSBIT z6)-{A)f(*=;`qpqxb_{Mu-eAA&FQaw(M^M6@FSQNQlNOOwsCC>xXsNUzMA$!c?!-u 
zmKN4TN_1_7MRXHezO9`w*>}lOdk(p(SXuYm3e(2 zE`#N-+fC$q<*~G*I!c420tV{X#Trk5EDLyGM=kP|?(8IP z`A(ta9DrptE!BX)xo4JM0khgj|Aefpdw8X(iw3(+Bn#8WcR1+{t&5F2D3|uSik-6s zQ-Ak19y-3C;MXwH1z0DK{)t~zf5po(LpQoDLA8m>`(?iHi}Lz85i}G@z|&k&=iuUJ zdt2Wky2N3Dp0gK1*ZsA5Xjv>9WC050<TSN04U(elc4Xr({-x`ZfeW;zciT?G~g9M7YaWiklsMM!G_EjYqB#Z5=Q7G zZgVU^6RRGXpnCa1GNfLtBy}V;Bq_`g(zgCkRrT#1KTb zZvQj;(*F0+RTxkpj2V9;XOqxny~i!@k9|V*VQx>0*cY!AUWU2V3Nr%fu)}GPu(pbs z`nT*FBakRF$rBsCZ_5Kt$AYXnxs*2=1P<{5&SKM@fAez5Dmlhpw~c$L45AFD;)(7* zLq*#Ry!(JX=T=+*SaQYR>?vyCJ`7=+`8&CSFw__QJo6c8Y;+Xwo0TWU=&x~z$R%-m z%gC+-M-tL~-SvA9w(L=nMbe~{$Zff*y7@(X8@@cp_`Z8SsmCY(W|KAPAKQCx0ZxVl zAH;i%;lk|RUL9xyVVLehis^O$@hL65(`*oojj5#Qh&vTUOHB;Hlb{orAg6F}2|y&% zN-XB-WB+s3>@#t_ev``2f07frp%Rbd-|6>Pkej^>@5J?~I9KYl)!jn5rA6f5lnTKU zK=FZuiDewd$t}96oT-jgUdBpH=|GXIN$7L`UWX#%TD!?qU$TVtLu7?01r3vHwx^Fx zQ%%L3w>JtRaIHE14E&P=z}4AuI$!+@9IJjSUD?sl(3l@OwY?7`IC|zyB1)|_fid-c zp~~rokP=YqPkfwCQ`KcmPn;lTieU`o^b+yL-Tel^>qy25FZT71OZ9!pRo~;_f$SrELR;)46Xn^b> zFN0Z*b|V=PnpsjwOkGlvpTPY|-}4By_6@^Ild~m*YFM+JsQO!srK=g`D!>)&DNbwm zlVBTP`q7O)gwbbEepd1Oa}^3ASv+vJ`O+V_B#amCbJs0|`d&^>F0M3o5kE&{ST{d5 zzuCJ-rI32?uV>i$w|Or!pV-Y)T&kdp%Q8edU;`-amuHNC$c|c=&fZ#yk3uJk6uyvy z;DMyiM(87TcK>+f+n=%MmfQ!qauI55nBiar<-g? 
z2Bf-Z&tqAWf@pzlmU*rF{@*zJ_#UZ_))Cm{O0XwK1DV9x7gQk~ULTgm`%nsS7Gp+z zIY8>J*Fk?W{oZ%C)>0_Lyq;BY{om)+-1Ft8RWwBjlCa0J5C~u)ww`^Cuw9rDkpZ}f zz(?#a*q1Z8z~$+~C_9c$($7)60NvZ0zta^S*B{B)JPX55Gwqwd^raZ{jnb(RP3#}7 zJ*m{W8^IOBs~{u&b|HtusE~+fuy|^NYRJeoZx;5qiW(eXuqbYwi10suShs;r8 zU|@j%vZKC_N*;lgPa6#UKsqXA6`zhDSIch*sgx%mj6WRKvA;8%I&K&pak~2*^sud6 z3s+el>Rn&@-{w;E7zRyasb2Y}n$O!A8bDyG7qB+vWn*Mw*HOC9dHN%2Uoi(jAZsjp zL&iov&q>aNnb_E?9Fe2`qIbi$Rv8cA&-==PJQ2K#Mnazl|D0YMV=@#nc)wcDgz))1 zB*TOeeCj9pc&$uo#^+qmYSs1evkVn7^P?$Ns5!?6PmrO2va#+5TJ`bftkhJ^$&|oJ z!H$l`k2|DZs+=a0XtXo%N!1!kEhWo*7vL-%-yiR(8)vJJ@@BOEv$?01LM_#P7dDUJcxk ziMqB=?SBrg{o`QxlbLTxNK(rfv`F6WkDmKDO(D7n)pWutdS~r}X{r+1-hn08L+@49 zq#qD8>C$H;bc-L^i4>81s|l0&oVb*bOv2=t;CP9|ZK=8iWkSs(izNBbWQm zmCd9WJ@iJEYA!HmK=Q|z5;&w+S8l%4dB5;}k4o?tFZMOn`Ez)zGWPY7rn6l+v}>~f z>jMi3lRU64&#LImlkKjo!<6G(W>89Om_Sr`FM2^`M`4)62w*V)XlpWriU>@OuseJi z9px`x-@16YaxUX2Irhktwe#Fqg$pKI+^ zq*|}>zI3z+vvGht{WSXT{K8d1#s2Xg7hbG~Z8Q7x+2lhOLWa;b&BNx^d9Ow#BQ(e~ zu)WK%lA~c@_+_IQ`B=$;k<=1IP07&CjS*-HvcRnkrcx{eq$#4mac-b3=fb*v&%~p{ zVzHtFhK#B!%G z_7lzbCEUI&IO~Y3nyG+5-hur~Mya%kjBs;&aTp_HC>&Hy+u@)u)YnC9&7zFM0bza@ z4x%mwEmsYD($<)SJ8DN$8DWVWWh?^2*}i+Bwt$uyPoJ8(LyPpCfLiM_hvG1(w)g~E(RP(vYh^So5O($q7WK|6gAc$L zhyVbm+wLvIU$U=m#O1$Y-FT>&u^|%!Y_0Jx4Wa?Xl-kHBByE(~QE_=RN2>#)k`Js zh$pnRuB8&{oEg$fzXM3m0@{3x6k8<{&S63$Q zy79b8>h9y~m!blqN6W0z)|EW?JltZHn*@jqq}#!{);O&fqN)3k37!&EqFB5?FYl=8 z=<=OzKWYEnP+c?xQu)olMo@J3bg_0E-8I zf{Mz2fz7)Y!#@~v)rqKz)Rdn}dkhO#5?mA@gJ>!u69_a%@^bja_jE=psxIQEe;8Ce zKX#s4JR3tlB~E4cT4pcOTfjQk#ZCfaClFh7#$rz*#J2fq-|>wBbMK1jiQ_NPrFpuQ zbWw8^NYcioo#?3leZ*G1N0$&GUK)ZyoRTMcs5(6IW9PUM&LnWd^UkmKP*X3`jeeloVHEw3pR$S^UJn#4wOl~;c22H18 z(611-1YLK@gMpKX7jhORDNx4yKF_>|2d}Da5jI5rfAQ!qfzD%L5J<7?b2f?p01=Je z#;|6AU;kpajCBKDpwu=u4RqJ}_s3v8Br>LjGTF#$;y2$Ch6%#^4@dnFjh`I4#(JfL ziD0#>o9eYjDX zLEu|7PnyYYLFddnHE@ZWdU242{A;FpQM|B6sWT}_U+T*_h|bjRvWP~*klJ|-1k8V1 zNX|*?NG4G3HWt|VqE-QODLb(+iSs2(k-Tun>tN0NqvsfKC~)pRHVwl*e-u%>N~2{P zOOO<0<)fkpJ?su0i&kEU*98^JFXuz&fKTTDxuDAqMvXweAtq>pSyWkb$j#eY+FoD* 
zp55h6EQf(*1asM8-#1f<5Q3``(~2vTsq8v$00503E$2l8eF@xP-*@Z#`823-_Avyl zIojqR&|oxr5KHrUq@~qC0pETL&48=yE!~&+?*Qr)h1-9Q8#}AGNNw`vV2YU-!h#nH z=onRKutQyI3?-<$L{&HB`D@4J{J;caImBJtDo)=?rGx*=Cy>~b)A#fS2Ib3m$o*w1 zG!kWEpA(eHmIbiXAA|HDp4Gc)$u{JA7Q46cpV9Mg(h|~&CE;+{L>UzFB!5Jlv%_8J znK*%9St-erwF~k6vEpv1{qKSE<`|6~!4P4zP&w_MCnhUM3WOlmHy0;&5aqhF7Gd`) z9pVX(GU~+wqd)SrT#W_>kvqR-VcP74+k;Wp z3##y?+{|`vpPT)B9>2M4#74Z~)wetuG)73&^NF>%*DO_=-O8fn>{ez^bzV5+ub$W< zYV*bv@6U82k1B!JFPnQjCBnkP$HV_(fvFs&==FSQ8HVkUHCGw|VeV@#=)X~#J(s{hIntf;Vjp@@IN#}$u zDLq|8`4$rfq~BQEr=5LPLh7Jm3@SSTU?Nr3$ zK1#_Ep8@p)ywLpa$p@XH6@AivZn45scaldtH3wt@LMf~g)FzczUrsV_wksi!-Mp&X zwvCcJo`~3(3|NJ)FPK1s5?5G1f6fqvOALBVZU^4_0hJyO$e1OUQ(l;H+I<>J9Q?ra zW;`!6z6n0mZ}_Ur4}t0LuH5K4*)Cg&dPC87>#E&5@E7z%WhUF}t&uM`w)WBRj{z#! zKMN$xY)!Z9*;T_z_X8wUf86S#QT~UJoY$b9$5=j6mV@LXZx@C-7P~su22Rd6;6G${ zdm}r8DDeDpNj$TMteStyhG8W!!#WHuHb4>~NE{NQYCjVA8(2TUM!!AACjMLX;ScrQ zLZ6EhBx-|2c-`b7XptIiuzb^K(2t-gTu;sNhn$~XNZlHvL#ErlT~rA$8s@%VSizZc z!z;vFKdG_#GaN3*>D_QA7Z)F?Ip6dsy|Meyos=^-GU0a+M%HTiojzi#UMi|z#(wtw zbG?~M`;kT3?6-GAZP+DcD-JkPHjOd&I;jCt@`#eFu~T0>I}ZjF+p@k>{d&cFWXM8p8N5T6?u z6axdq4E;^s_x04TBGWrFw0KHXWzp(uNjwCj_ILLgpb9N3DiDqQ&iWg#Wx|6!!`rQg z7nW}3{nqjiOx?_8?%#_A{cR!+)qz7=y4$%=4!tTlt5cHnNK30ZwDDZEqOm8_xt-U= zHOc6Y3N5YV%T37MkU?oUI-fFIel~obVxjb6*$9=yy_2@H=-8lUJ~;DS{jh7kf5FJ zYwshrR+C)^S5x1U0GvEID`Hnt^ja&ORu9z0s+2z4rry?iV0w79Ev#s*@%K>ave40N zzZT?k4l*)+nxPP=skPl3wG?7;br%I*RcL?+XyO>~)e8&qZH3yxzH_p@#M%i_Qr{QX zPF=y&wIo>zWY!^2TGD+J^Z$5wXiw&r=~;ocBlU|VM50e(e$*jw^V45R;3WT(WAf|Jane%vnA(#1_oJ{3G3=Q(GCv-kS5op3PF2gR3v}- z0QGY0dZ(x_fECUEvFGI&ganepKJ??}a|r)nid`}>F$o&r;(bZ8x)xqZaKeCK^`!@- zR-8iBmC@X*(Bi75mp*f21!m;oK?(qq!ym?g=0!h&+~>8B8@;(1vl{KElAx^ZKrJ*% zgq(3$Q(QklO--_XtDn-K8V6hDPTB?ajC@x7^i?G$C?Gy>&HJc(TIV}grNIxm$BM%S z7zEykm5{GAHGOqH7&*!*gV)4U#g=9^p~Fz}SxzNS!F6imd9fIkG2^d|7rK@)E)P-eIOqRa z+^r4Tv^<{A>xKB!ATY?sh)WF!CEcCF;&1e#A?-cqAViUm)BXsGxKneuJ#f8)1g&4& zKBW%c#rlWen3QH?^jKHh)TAZ8hf$Ag8E-r|I7`JWx8*-Gb07PeOs3pEA5M)%k>+-A z(rBU=yG)HDTHn8)Dgn@tgfsCBHsW{RwAS%h(!O5f7tH=8 
zg5A7D57DRFwai%lZDEGFY}3*@?;DckSd59q%frLrwY?@zn6Y6hz0WBDR2chsRNZN@ z{?-{~u;k*&a7vPb4|!K$Ue9LT^^bKm_c$16)P>D78*>Hu`BQwa58>U`z@iJNdfQT} z4?`^IdM*5!9Kfa?G}u)CeRhrD%{N~q)E96^|6*4w#K1?&vTQz<5lW`Y>o!`J&>Y|u z3!l%5&3-nJ>$U&ttq0rlyfP{S5BGF*S&`s{>E~w=6JB5xVBi+m()RYYFmdh#bYwl{dJD)7YHe3E6Vo*{Wsb#mW|D>L)~V@ zRu9kGOMizyG5vZ~JD>L5k_V~0kf@^Ccu}3}eU^M28}lEyc%>UX2G-VkoGF~!Jc7pB zZ4I-w;%$OpyR~bQAbYyJmrzSTb855v34|LInblvZ zvxh}>NH#vbn}h54=gdaQb@Sj&{Z$z$2Q5hdLJ~gsw=shPeK3%uo}I%RI-v=MRea{; zK#O2- zSo{Qqolq6(FuGRhXZbEWA$37W#4mI0Telw50g$L;WzRn_r$7&hS}yzFf4VWeGG!jg zh#T4uW=)Q=2hPl6Kr%gaZ>K?>mG7>dc~(Nd8KoNtMt0q)@|pWPqRRA}l8>@uh_n?PamybX;~s; z5`$b#0=?X00juh)ceD3;2N+4QgJDaeocZ8^%&Y5Ps1XJMT~C}pz>(7N%)9cgg6orn zwc&yRzJOv)3SjYLJ8D7y7>{J;%LrP&ow>4m;^4i~Hg^s-mlYgH0e9;Di7%9NP#X$? z0Zr%D8>HjB8`DeC<7C{KI#W14S8i@XYIp4ak!EJGywL=DusZpsLFqgvzde{G8q*u;keIm73*aMz-8a&2KJ1 zi1Rp3?49j9e_`L{^@37Q&}s1ZBp7sdfg=-qr*`;fr2o3QajvX}7$OkyPdAbRlWhaE zD0P7goy{n9ta zHqEjM^6_fDN&+Ck-}1XQx%ZO;%=f{8tm#nh!ALr^Y|_`m%a%AZDbNpql$7GCW-FC* zS7j@$|J!~z7-)a-qboYLDkaa*tzkY*tp^}FoezKH-k53juh>oiyH*$~hqc&EOfH_G zc{si(d8oZxQZlhVV5(0>xFp7NkX-6NAgjujKQyu3!hk1kaKuUX|B1F^-`q*oK+ zG$#xwshYMrJtj3K_KD+@Zx)xB7I(7U{>r|bws8vFuCXG2fV#v zeSNobOGFc4Q9lZRo>209>8%?;8Pn#nDDxC#)>a%IDFnsr$)Gljpot72QBlmz0~(eh z0j_;nCndM5`v@hqOWMBD?_1WC%wn9VwqxO zrr1@W>hzjM2d_eq%_a?u!$=v#%c&36E5X!^Z+N$F%N7*l+h|)5;_D?$`vZgx*#(AS zD|wI5bxPp37I!JYvK8Qdfnj(T5>>%#8skZzeA7(!CZB(cx(4|StIfN3+bFzrGxqpN z?hOnB2Thqc>c8zOUlsUI)v>b1l7D9lkHWy!oN^z3(Ab-Z!73dXA-!rx_IW1`3-Vh8 zDZv<@`A1DJDNLBA!110i3fR~Hrz{Gr3@7Ij&iEt%(jb8Bulm?&tqt#B@o`-F-A(oc zU~dL?gk)ySXmy???C&}P6VGi2Fs0`|=1-~@d=a6h_^cmD{RZFwx)27otjF$~fmZoc zCN%smCKG=hcf=pbc)S>C^Xlv;R-6Wa?5#m=Y5 zf>@FRX9!tk`C!VzRquQzuN;43Q**?qqVRvYFm?7MyJVmmquy?ZZx_Gkp&s?1dC?e) zpHwI2Vly!}`9+aM-MEcz6$b(on_F|lMf(CJ4&)V>_4jA#M@>bsTE4_1w3M=Yp4KmL zKRT_pT9r_QFj<_<8S+NDlA2KDXt|H%MK;_iWbVRN&Vtkx<0BJ>b20^mg;@8{sg5ie zTXBH9c|=4f;{f?@%H{!=#hJq)6idnmShe{8sQuUsOX?xERG7QNogqpU4AQ%Q2ODT~ zfg2DI>bNqqBPXrbcCMUiqK4}?Z>6fdbTXWva%t4c8Nk=71wePg_2~}R4NP^A6);~X 
zSZt^EdO?Dz2=3O^XDt24dnF)$49X&Hl}qEkiQ(i@Hl7G%?ZxK}UN>W9=9%?8#iT0D z0#c8XtMae7tZ)^7Mh=T(6#6^44T{<^>%{ZRXMTW#EX;?_HkObJ zXNKJ?iFLp&X?kkqdtK4$F815XwWxf?zFy$_xhPt2C;Ml$qiuQ6Lb9w7xz;$5AasFv z_Gt&vu(ytYifTA3-8(btv%3AH$CE^UMxtVw>gOZ3UvZJ;ecYU4^_!kBj1v^4vJnnC zB*m2P(EiYZ4#9SY6gWrlC7s)vQJC8jy9v6?f7qy03lxDJM9@hPW?@3w!Ehka?^O;> zV~@@oBYp&6LDsC+lE7S*p!20AUV{Is(Pe`*pct-icKt8vlR=x&CJ&^zyy;Gp{yt}Y z80RuAMgJwfHDBb$dbLT=niIt7MlwB@EBqlOQBj1=5GE!DS2zolgiy$a`#l|D)jrBH z*diWy%H9Y#c&#AVg5LKwz+@SU-1tnK+URkKNsNn?@6nf4*1Y{4+KLs@9_9;JP;sF) zgA3c#6tDf6r->QU@E`FqfqyTgh>jXpQg;q(j|a`M6m*Nl(98QLW_asGTqH(+NWx*) zKJkjfLxtf{%^PRlr-<|WWL$c;1*-b?OZ*YE$Djq zbU-QCjTRSEz*1!Ii~a2(@Te&4=L6>yN$&E3cUTP-ae zI^>$dlm>fE{ha?GfXN=cMn^6#0tt7Iv$_9!)R4(iN#;IFZi|67P5tG0{Y%x0QE@O6 zl~nxYpWypMMNn82oV>bA#nDkhmzbewSgP6em5=>7xvL}*a#2LMywbn$dH`H17F+50 zz2$>-?_c)p)3Z$dBqn?qw&WrWga9JOQPXQ8uwL_fYB1RHeb%(?L$yAB&Iw9WQxFPe zhA>kk%st_V8JcMkxhX&2DW?Z4_nd3xFfXyS?@;UqKQ~=Q+kJ3=ifP08 zlGXoRIOJ8jm)959DbCflqvtnl7uS42J~b5D7sIP!6kfQK zl9!)C)$quLo>j^e^yn4m7`x6v{gTo@Yw&UMWW6ZnC|LBk1HJJn;2?Mu z>~wvBSlt`^-o}YJZIl}SI=q<9;G*!nLR_4D7enT&G%4D6^{wSru35~HkrxR zSwQ=ybAqZhNvY&0rlngD>^1;%H2nBQSL>L$DYa1hL0+6_-B;)9Fs>=fV*5T?4F-jX zCFkT9t3bOwB{`-qo`|>r)jmHmak&t*N+_*NnGSX~vvk z#tvmEEVu*+S-!U5cKX1k9A*yE+etHx^^dISF&V$|uha-9G))~8)aE$G+V{C0TG2R; zZOUZ2?L9DjFI_;PoLt1z9sFUZA7swUl^9kDCLxQ8g`9p zdenbf0AA-MSZf=MTE&P5m(u9go)4xu7ula(KvO;!Odp*OBXQ2QLk7W7LP=kY_ZY4?UF)A&O5{T%PV}1Vq#2-a zB*D|?pKspw&3>w?f5GR8f*zG_pWe-%oJeO%@w!ajUI!H?6uLXMeMYfbw z4yqaNy|SOr?dP&ZlyI#M#B4kcu%SM$oW&itH10NWF!OrXj8bRNGVt&y9lqU}xVVq9 zWP^57{V!{%4d27)6cy#M??u!H56KEV4Sus zNZ#j55Y;dDYIX*A{|eNE?bS@%;5CguF)NX`E%`COUYJr>vPPmHjKWWM?r=hKM9P%W z0jD%W+~)dX>&|9$EWzY}>y)?p`EbY^dh+HB#*HGsJ>w^p4P#7!=9wwt9os22f?P+7{ z*>Z1bmcEeWbkpZRB}YyDwwjcicY2h}n_GZem67x5IAP2n;LhFTysM8gG4VTto(q=6u@<-y>zhNB?YMN!TUu`B6X;d#@(5B(E~Zj5OB zZ(QYfp550ZPQIn<6D6X1banuiA}~+xawHzZQmUD(jc#pL!avkvcKxB2w?88j1$EbX z%&N>mit^5Y6okWp0kaxOZ|Jf;H{M)bwfE#1xV$sjvx=_N@7~-G=7w|qkTPWpr~IbS 
zGY(`af^F&MZn}gH7Ctt~XAZAGC=0(^r$$qr=;UM6XP6lSmj6fc5E8!W8DN3Z2H)H` zcwLs|V9=F1b>7Z_(F5;q4}6pH4`>zP^>?|2hq45cGs}g%5f)(^qtq#I)up~ zKOpEH_LRn5mDi|~x7n^Gu^3S~f7<#`zAAFpvL%shtBKc!npvDgoP=oOAUZ&W^QgL+ zlPA;17{rS@YM%SMh%OsY$Mdwh&#pEsIn}dU+Tn25UnR5dT292t7`)%f$>$9U`<)_{ zh)h!%U}s8$$OeyghsJn-I1Bd%Q`Ub8E)*mV`k*0bY`_WTi(!Ccvl2Y>(cBHHA<3n$ z@CS-&#E{bkJEnz!M^^IKrORR$ZkS?C>vapS16hrgM{K-hzMZ**9FJE3M9wO4ZrTUz zW^DIodg&a*{1Ps96Y0Qdcv~}{=E!eInr$gOkqi2PB!rLA%H$>Y2ZMtbwjMlK-?Wte z_mKt@lj;%7EwK}}4#757c?!mGy9u|O_|L|7I9`*@ozcGh<)0d#0s)j-ugWGjkZ5x; zy(}kjSJz5|Gp8+p=-21Yb75!#;u{DNqw zw8i1@6{GxDz~KGGA+OA@uwxUi>^$4wbNx4{D$AJICfS71&=o0SoNflnH(KhyyZn~^ zA^J2PxTr!?xHc^WgB#ZfnK!c(LjC*^Ni2APN2ddB{}iHk?gs~JKUvScqY`5BQDViS zi;+F02~G7?2J{@-FZz4z^UAJ0fQj9vkDsAwl_Kk*VxD8HeqE(=vkZ7pXWb?lcRuVt zVOH|8>m1gz`+@GW%!88sCIM0R#d-C$x?-cx@!iO6Le!gz{DP$*=*YGE<41rgMo7cX zq<2jqNkEvIlGK)Pbi?d`S@4+Z^@x0r!D&Iy7zy9p$=~x1a9z2UITTb;CJh`WU)SCK zyXngg-iz}xz$@;kszOt>@11V1{}cAC<1MlbSoZz&6hHWbD%Ep=;NyR+Fi4U)Xf)(w z5!ikxwJnWD0|RzAwV-%l*6KIs38;jjE2b8ws&zt#-!I?Qz79Sz zJ>nuv!Qv9+`TkrzhN?mn?Pk1AksI%>qs@URevk90<0!}DNTp@6QV;Glz)RCqI47n_ zN=k$4D+C?*B*a474=GIXDuCuZ{BADLV?M|{jEPQ|@)Jt^yqNd_B|6scR+_~%Oo?b3 z!;jWiB!=-;ZTVbP#moLj;-+_v7lSp$#$O$w8qvXmvzArfN8+o#0f?NVfYl>QT=SUd z?a4`13*mpN-IvN8m}V>+Fe8WoC9zroH@^1GezimDg5Gu9^xYxyaL)lkEW33;E=^HP zm#m;3=2R72cX)l$QR|0KRUQF!&LJT=GZ?qZ&(Wyh;&J+F+k%&p+ve{QgDEt0$MMn( zYU5KV%IXPSf_6F;T4qBMlLxF%UowWAy|kafq*qt5?+z;Qnm1gmYe#&JUuIC|oBS{+ z{75-*(38ObV636S@cf^*^lk|aoA@hr_~}QUEn{s|CbNvhX`AQX#HWza0f{&41TZZ- zz4ofzh~FP3@Z&7qIVjnnDz=_3vuQJZDEj1omW4K3cXP_Bv>QvwJbCVV zni%^w`p49%+UfrVm=I_083-hWmFXJk@2%_luRNLF{nOuwhlSTKyWNW0G!>U68`dO9 zwLR1PLgRPpZ8cx`R%Cx)BTPut0eWqdM(LdHQTivWX7*3(|v^+H+Pl;gQ+ z$CHLM`Lc*$(JW+x#q1eZ#Vsq{H2D0kySLxs^kt0CvwG@zbf=5Z+jnqp|Nh2r?fS5c z$Uy|Dc}`uW{5~u62{wNmkSi3{j||O*jVx2DZNr;)@A>A=Z@|}|c^5u9Mqs@vv8p^b zC$JAmg3!6?$V2{b1!hwJbD+msfA{83o!(57uyQ0kfbaWZ#qBeyu3F;zeiSfwCkPb= zhc3GA9e0}i1lC$b6}n$+e`0XoKpP2?tc`yGKSK{|qXEc)-F^Fx?o0hLFFOWc;4_Ha 
zK_Wkc$f-scl>5e`4L@vNfA9-d!aiX-u3%uAd6!ndC$S=V!&HM61k>?c%R0c16AcvF zjP)7V-&MB$=I~>B#7l)Im5CLHFpe?N>gWXvE?angAj?#4+WyAjndS|9zrE{PStR>~ zCFBV>GZ2V`?@ltt;l-aYQ#pVfS0(n0C!@-8`?Gr+@89`JZ^-pS3Z@u+L6AIUdq$jH zcW?WvzE`@Qh$JIc=x8%dS>$+!7ZRj6beJd<48xkWRTi<*2U0()|JUu;IsMrT2@-e< z&!?#nZa6zPx4yFHzMX#~Ea5O%oI<|<0+B5EneLd>?2#fVqS7sTN!)tiySwiLpa4IC z_LOx4P<7Rk+W^?`?|kQ;mI3cm1KQ7Jy)F^i1miAdXgLyGM@p)^-0NPp0=# zV)%vEF8v%Nq;a$#mY1j1#w$swhdqv6l=xtLMe<5Lk|Y0{&rzX_=4|V44?jHm*(@Nu zyJ~*&V3_MEH^<|MhthW@K3sNo#kn|vpvSpH7?gY9q22%9{&e%bc6me@_XL7+gFx{G zS1h`X4Fo91U^T3*sD;7Z{lDqn+P(vM5)7XVf><;;xNo4j_vOx?s95gP9|TxuHL9*& za=TEXpbQb`E3T8cG$b9x_Tpo4`%Wjcmm*!jh6ALy&=K1hNl z&u%g4JN!)A*`AE4Ut9Z$)(y>HvC3=(>GN%3?reD>$WA@bF^YZGUXZwF7 zr1VdGQ%tYmu208TB$v&*r0U~&If0Fzv$gH1=7*+AetlB#*|N~`q6U?tVs}c)KZmhY(I2xeb<}+2SkZS1rMXI5%5k6bQtW1!Q z8?pPvj&13Ksb^)8%Ed$HGl&qW#UDN+KZD3y0b^*;_OIK$w(r&MN9<%|S}xfnd^Yp4 zvnrNXTsZ%udN?~aUEK0MdpqUMOY?!Ttt&FQ%_uRMh zLyh0sd5>|m3hMQ?U-&|>XFL*Q9sbFIpCA6o{@d)NE#CG#py3!{7%jETBCESMx7}B_ zX6uCmuk`FBOGp%+btAm%`v+l*curCih7l1dM%sp6?%CGy6oA|t+>=BYpPN27c%-+k z>j@RJkjqL&!Df2Vb?>;{E{gzw$rlDYs%Qf5FH?3-!!fg zn>$`)Cg9VoF*-IL`M$ITkUMdz(8C@oE?RI)*{X_~JPaa*)pxedA^{DW9v%~vTVjn} z0P~g~G8!zodHLr^iC7RWFfm&INIUFOwKe~xcw2NM!i~Z#A>6R^{IR3HzrOdkDnFUC zVa=$za`7F3Oof)AAxLEVGTGPGzI@y6ukW}#(=}3Wm)nw*pGGK{yfe{<9L5NeBC?Ya zmFXVY@%mTlRyW+cW388V34+4A+Y@;^hn9f0r<YfJMMh(;;!GfZ4btXIq4G;BTR`UN7{yt z?*3Zc=U-b}dv1Sy?-N3atX*cS;AF|8CSK0ZfHX#$AZ(vK7(eHxov@XNNGH|M|D*bQ zYR_r7cgH`CbPf-a4D$2Mqz?DlqL5>*h0mPv1~anAb4I5@AQY2JNNjqUn81(kM*!iE^|9Q)MyvL8;w5oVTBLFF+Fqp;$ zH_V7$ef8qMwGxqP@{R(`_m zN!rR7;|N38*tX?10FWee_TyN0Sq6aOd9me1su)GG*fMO!D}M%{R7|1ik;ZT7jD{7l zXndgXY=9(5cgWQpPqln=sbQhOL zAxVa87!zca5$TbKUIL($VlIGM>m}j;kMI5OwqL*bKb5PiuCDrv#h)&Fck(J(q~dJA#?=&tGr%`E9#;tYgv*jF`IZft{t-t?DXfSs*%&>;yWTJzlk9)BXYatPJ)bCBRr%%W zYnFVf^5O*_Ev_v7b0VRPXJ8!dQ@;6kSI4=eknfx+NDz?{BrNGCEp}b#0dH^nrk3>` z8;|@j)!28)Xib)FNdza{!m@l$4U)vJc3~;2zc?9%gtR6*O|WK)BqIP?Mvl$5sp=BsAr*$wR>ba$lou^9jwvJt zuYigILi(>^!(h*N*`7>O>TqgX?<<{8_0@JgKhi$jZ?p$Gmt+tMpr!DeSxO0|6tvEp 
zr_tAq+C}lalA5BjVh9$K=|+HIthj}y$M-$p^khaX+nSTuPZGqQ`Me6Dy)iOz z&_HPiJzLt>ySp)TwV1cVM`BV$xQuh0K!c>+3pQ(4|s>N2ak)l z3jQuCWX1fCExpSwwJTY(H?CzEgh>C6?ibP<`+ppkIY-h-FtnJb7XH~iXKjfsik-=t z;SPIk2fkE)w$qyF9Tz`F061}7<=-XFt^5ExCVQMLetRglFWcV!=-~&~eSclY`X|@Z zrcIj|0LseBuwumuY}>XCLWr4N{QTPab&XL)`3+TTN|wbJPuAWfD3aRKw==b&`#(^K zKU%&LeTOJL+DfIEXRnN1T5;imE24{H7ud;2ja6!ugGi7dgl7%*O^X1)UfNB$gHF@Y zH&U+-H1uxH>`8AyFSm{>V=m~jLN7mmB1%|PW1Sa&ck;@_%JNI13!^Jl%vvalWCRQd zb*x#9$B;@wXS_^yAlshl8g3julzwS(w(7Z^zpUFuATJn3;bJs<3QR(l-&}oX>GH%` z#x>6PF)A!!ryBbj2cGJB01jNt-JGM~PqFOFD}P;ne&zLM#4{{N9QelTtA}?DZiAr0 zOoxFkEWV)h(#oqAehMDD`F#xwsGK13hCL^JaG;;0=oo4~+T^uo51^9|0~jUjmQqTs zHEz1;CVlb67vrwG?t*1ma9tOdUV16EZ{H5bai*$2fx(mz3+Lam2?MHu#eAJV(fuN1m%Aya>yS(}qtJL;4$zxhsDTVINZu-j2Yk%;Czr6Eh zZM2X=PVONZE`I*?6XC%J9>o6r`%oN}$4hNc{N-R1N)4)lCLEezd!V-X(WNZ<@ zE`taV#d2cT5WtRs(F{^FdDk((u7QlrJrj~pLb!vDiLXqaQFZ0w_Z!!6&$f=*)SmwR z!WI()rwv1AG|0-a$s~Y0U??+8w|?eU)bFb2hK2@2A`wKRQOut|9|sQ}gk@Q?a_2%s zFrFcu(JUe&=TtFS16eF{_sKNN3~0EZGoCiC0Z9QXZox`K6452GMPSS?UKmRjEhsJq z15PUI3^g6?BSk$UM}}IR0jJOD&oX-k!V-|hN|NvSsONIr2jP=~h=k5+ZL*r#Fl0=@ zPDCQb)zK;yx2lCD%S+Ep#Ni)cxu>CPux9BQeIvcYZTo8Xc3^}Z&KNAKS%yQ+O=J~Y zzLOfm%&Itn&M;XGvUB%8y9OdMGyOlU>PbO**yWPtiDcE)ORfR4J-D}jU)LYnU!HWn z<$W1oMzpV&X3Cw+FpT5tVG;xgg0KaIB_t6MGiw5II-SOQ-}_!HUAh$O*RMxB9*66? 
zSiXEYnwy*9c?EF(lRpnxGsw!(4~Ah}165+35)w%wA}{R%l6bPQB80$DdI+EW$IoKb z*{fiTfe@4HMw(>chh(DpP%tmuETNG2rF@Lw;qE;jPNj^YY=sW3LfS)5<9#!YLfRXc zozv9VvZvDqV}F;IU49NhICdJGZk+@G!2r_~*-~Z+h3q6x75Q{XG;PLxb|P(5i2giL z@VI9-@K)E2Kmq!W!G=KjQ=pAs{Iu3{;&pWzQ@~GR5*Yh^PWO|2r4V8?esr`LKkNC) zzWVJ0pF7A`cA?SBN4Se+9u53IB{eJ@V`T~cJ^FtP4*J%YCy?e(uD=TffkyA6XK237 zH9~Tfu@<{O)&7=p%4}bgp8QVZG*06*PUAHGU*rD(d-&A3xz9(w00000NkvXXu0mjf Dl;TUh literal 0 HcmV?d00001 diff --git a/doc/images/axa-small.png b/doc/images/axa-small.png new file mode 100644 index 0000000000000000000000000000000000000000..f774cfc10853b0a33c225c711b7715dc6444dbbc GIT binary patch literal 11616 zcmV-mEuYefP) zaB^>EX>4U6ba`-PAZ2)IW&i+q+P$1>avZyIg#Tj|cL{humV;-6ci`pwOZ1SUBub)W zhmYcj)NrO7KxJiR7SJ^NzyIr)|NQ4a&bnqzOr_?Qv*o|oV)LDEs(t=@zB(K4-}gTs zUq5r-zi!@t;CU(VH9Y^!`}O|L>*?DE%Ip1ZeE+&B^L3s2y3p$fZx;-@v*(-j8szIj z!9TCN*S}5k&o@1Pn9l3}b)Jv=_WbC_f5*aNti<(#H+}~f?ETw6cM6Omw9xfBfAfmh zE1vIt$oG}<^ZJ?hrTb(7-@p4|{p@_b9=skYlm)D}}6 z#t#0`oM&Z^i*C8@j@$R^bdw=QzkT7W_tS^_S`EcFKW|W?{pGvD1P_H{P@DZ{`u}g=LqJ`+4{a#jLQpO3`3FA-@J=}xbs@m zbl2CJ=ZpREKLs`r!E~3oGQn=gbBU4MH*AF)=U|iJtvnmfFG=CC6-ijDW#THdJQ$!RC6u0)>eD-EkI+* zwA@Oot+n2{X^)+2b>823VfYb79BJfHMjdVR34CUpY35mGoo)8zS6EPhS$UOJS6h9% z4U~4=Y3E&b-EH>+)=oI_q?1oM^|aI9qxOpG_fh*Xa=(w7dqvIPKToO5S}F!n60CvR`WFPF1x_RAN2-UgF3H+$#x zhtV)MRx6*_!*jp0TCr#6>`oenaGgHJ0fL*wV?FFlJ*v_W-aG7+?t1+^kf2d8hOv`l zYHL%Go;$tI<*pZFN#(<$ye-MJS`UrC^O4sGj9z#f1XFy5!63@);jUGEZdFcBGo03H z0d^;c-S@}EFn<@DUxucdv*)Pg1mN{CFuz$5uJs-h)xT#PAp02#_xV#`eu~S-MKRx2 z#r*;y?*LJ7ei<9{b8vpOD&e;fF<%VNC#&+^qD1otoctJ^U&N*r^P4sKtE4es2+o(1 z#{8T$|CTi7=dAg+q%l8d&8JCI%lv&v%(yWLE!&pk+$+r;)=a*%Or3(Td)+?s?_629 zATm8Jvowq9Qf72o;-0o)+*<{4`*r}tZv9ZqcyS%is9^7wH}!@QW~|FLR~65m*uHsscbhTk<($h+$oHji5lvj zw-lmWMyZ|Jn(L0WM~Y9~b=p2G)n3|+f=OZ9yS3zZTB&YZnUC+M*D;OHC2YG06W8sw ztZ`#iN9wg8=^0oHDUUWsV?w=Vx6>SPmpNC+gO5~!&cw(R`DVIBxjNmi3#NF+_CxV> z0!~E}P-#~UeJieM2n{T&raD4P$4uu>3cD0n2C{?d+WlBDP$^YvyZz!<^Tjs;2qjlU z4zX;6m3qGzDR^-}m+B~wb;LK-$E<{aj}$UtJpNj_v-u_}S_5DVTk2@Tq6ZycxB-H-AE08A5-Z9uml 
zaF1kbCozH@041t(gEPV&4GEXZk&kTQ=Q6_Ijmy>*SClI^{n4i?jyZ?LuazdyEY>2DO279 zs{393PVyKQaKX9anG3An;wsJ#y)k^P|0yx9$d3{#5qWI&t=a(pyRRQF6j9}+4PT*? z5a6zg=_JL%vW=8KsL>o#TM>nLC)WT@_I~u9tWm|urH58Bm#M7-gGNe6%<^0J}_5*_!Q8^eS#5b|^Op8Xzk zEW_Vvk*<4I(6+g(YS9!sD$KLuqR|f6vKQwE>AitpSTBZ*V{ffi4Psx05}slBFl%xl+Ym zP*&wU&>7@DAMez;vY@s^bq}C7=#4R7%)H-5w(7e_Uwh{nCA&)-vv`E|g*MlJQlET;d&EdNF%=Hm-i z3skVR0rqvR$3>;xkiEyAP8QcjO$hT@>q%Bnr-b0_sS!9xJaj&pk*yAOjL^{}L-g8r zuhUn`xF6HbSttz@8dxOHGbRcK>D3bO2nekb;GC{|SOOiul5i1G6>A|n5JrdvGzXA{ z`xz^ouvUr)G~D@|=&T|xGYf^g#bm9cWZhibv#_DUgnXA(m>xiE!gl>D#f~_d5;Q^qV6~!QRAFss*B_O!I@#jC*67v+3pK8e*Xjb&^ z>22Zq)~de<`&-n_@8zqM(gy+Skgr$0;(7LF=64qSa})kMw`%|XYm+Vn0qY{bpfJcL zAMM7VP@j>h9n(TiAzGVEPOE)Jby$~F9^8%iJf1u#b>>yb3^YMA)XS3LeAb5=+h&2v zp%(0fdRR?`q2Z%S2so9y;Q{=D^u&&+2Up1Ug*Qj-SaG0;O6FSx+lnJY&rtS30HH^HI%|m_f5{#B;rN`HtQ?{u$ITm5 z)X*Bh)oB5!6&6>I3u07I!coVz@9MZDNSO|vDGB|$GgG?_1B)SwtOBcxovvcRLv4u| z92^?52Ddy`m;?Kw-G_->g@-`IL>p@)rlXsix(jLnBa4T&i<&lq-s(_9n6yE9%!g^R z|1s}8hR6u9%6(g3J}Gs31N@^?`h&nR-xQ7-&eF4C%jy4(q7^#L|AnFzv6}l*0%>+} zlQ4awBxmw6H0|xc2+ro{b}S&UMeN9_ZB@uQja8wNGsA*gOw*Q~lSqkJzJmj}ndOp@ zj<#6k$ux_Mc-la`k)JSCXHzBi2^}z(alDqGPqG^QxuEJ&ou7qwMkxzx=*w5fpcf)v5^3h=Uw?DjT?O_56{M(CfE^MIa8$EvP5 zJyypx5WY-@Oy_Q$YMM?PCsRCR_Umc~@h*Q3|EWetneVt^2`V@L<}*~yNWw4+Z*Yqc z)krQ+BAYs<21t7~(t_)^%ZQyjDX519Qh9wqovbiZf|3us1LZG+!PM4;^0ld_giy=z zxB-u@OaKw&ql8V4Xy(j_0I6ET&!mPbh29TqfOot6NDcIGi3hJvg9U$Kqi}gd5~kLPxQI42u9G~0=|LjP-#ww)NU}h1vWyV;+u2aWrB-n6 zBZS<6AIYG`)EH^s=<-WsyaRnDMjTOkDMNgRD}sQUqsJ?{w6pnxs5j3*zg7 zM_0fZ`m+M-H~K=KWeRdPfYu(=096z)>chD|gaV=F@lcx$7`^+N^S6U*Q3GvR4cE^@ zy*P#?z+jxk)ON1d5Ppq^kQc2oNp>ej`nTHx4a<{v_8) z9jrdAIGUi#Iy&+YP*DxMT4YvaNZA;}gEP{+LG)pw=|&Ck5mJe&8``NAO#nPM09s&e z?Knl{a$H9VP*s{>!CedYgP$Z+%dMfwAmCJ_t8KFRDl%)`N}JoFmrZBjK)?AbRjO&vDxen5GjTP9x+&`C8U{aAOii% z@BuNU#cj(A2yr`bgy`p73Hv7D53Z9kT@CPLst7+oJV=|C9WHYr9QRVE`6zTvB4t3~ zlpKW00go2)t4+nfw>{@Eao7{J2aS6HNT)mGtU#m_D@)9ND(YbY&bzQkUdNpMmV9WC z!`46QY%=DM-S&BJzRAuhOK#^fxO36|3aGlpXm{^Hbk 
zBlvirk!&tph$#^z&B6^eg01nmc1>rt=5{7qp5UBI;ycW!|LnNANTKwsM1j3 zWz@XPC(fr__2s9^kXgx6|Cmb6@&-fz$VIA2SL3I8*bMmzHq(e({pbkX35WIfda@DZ z>IT`{mp%_gg%Wlkjp)^Kpc+pEKep3Co1tQ5uh3FQvH>{75t=RP#X>XH8Dwk35QB-V zZY)Cv%{bV^n2Q=D$imFE(ZtwMc!I*_DVALuLx|=^jHx4-V^VKAMyOUNAxyT2Yu2E^ zf;w`8!Z)>Bmfn`%B|H?h1(H}JePBRBiz>f#>f%r}!F<&oADH0Ep-eMS)X0!Xeo-rO!l29Y z_7sq#Ho{19E)^-z-~fuz(3UwKay{Es)VV+<9|psMn8t2}L_?o2{F2 zx%fEnj|`z(%5l2O4oE7WoTlA`EWc|04}r0B&Tot(2zYr}lsC;zFn!Bxun zieE9>iG;v(>IB7!jlu}oue3LA(fj6A%x}y}@B@#79ds2TtAk@Isr(a13NetRqA+cV zyW9*h+1hx7VuvpQB?aTc-nLU%8kL**NfVzAL#ce;jW3|%3;7z5G+70q7#I<}lv|-N zS5k)%(;~G*_ZwO2ty&hWyQJTe>^a(sCK(A*(P-nwW`%`};$6Trhv~oqY8V__byHfQ z#FCtp)PoF4_*k*sVMOH0AlZkyuFZ9eU=IjsSS$*vur^RY`Sp|~n>z&-xs7U1@XRO) zvdC((MmVIc08NdkvgeFgCU~NW^BSP8AVndRd7`#|u*Ibf>>GuH2_@BQa!*s+*btQ0 z5owbohQgAxi=B}>xMbG8zX}`zY(MS|kt}(NZBv0 z0MMXD)FdZ$`0Swy0~CX*4*UAPj0P1jC8QZ(JVehOq$o4QQ&5!W zFc3FNj`Ss9!8r$Ir;7O`LCBAo(VM-(h7&(9Fq@T{3_(y@q&@(f{|!V{yAc zfdfr};i287xF{+;GjZ@Pw8|r>vb+vnAZJ)!5$29YGmMBUO5r1oCTe@QoE#=c2i$ZJ zfoo{bdh3`&r)nhXr^lwlnW(}bEr?%3HURpgR!y5L1vLqFR}Xp$NR7(Cnv*|d2^_=S zS^pbRSrv=$BE>SQmYT=^RJrQpPh?U=UuXo13J_NaV6Chz?=f)_1xF>3PvnAiCs^%B z7~+ED9Tw3R*HQh4HW&kTj8x!rRL~;W6n|Nf#K5w*2*u#ihE0;BfY)G7lFG3Y)hOFh zg&skVeA3P~Jd*DQs*-!LN@+7#tudt{#Gx(LnI;xO9RhiI!D-+A7g7y6f^+~BMHXmi zi-v;ev~?@XB^fxdk{H@Wav`wlF_cCWr5Fq5q5K6nJP;wR7j;LOCV5!7r|obYQyhd}r!IffmK-(-~ z{o+>LFZOzG4KwRW(!c4?e|omW{OWK?0PMFZmcDO_A;^nVrZnes?2e? 
zc66>9truwA4p49Jgj5pCYiJ8IuH8EaNgl6SL{?CAC>{oa($+u9l-dc%vUX}1b<<`a zl?O`-5pNW(lAVa?mpUcwBc`Z0)EE@l@q;sFWOJjW-VSMEtc;3+ zNCntr56a0=a|n($1vfwwAf5P+b%D8%E?{y55;ZUC3Csl{rpg^Wf;Vm{Hn}D;n99UF zkuLWAbaf&S{~HD5rNmMZ;ns!$nalUe%n0Zl z2pR~7Ne{=nV=U}IIK0}i zj47)fPgK<`g7ut1TvVW85rq}Xg%W+B1L$W9l;u-igHl84ZCK_7 z<$;IU)L=Alu6tc7+&a(agBOkqcA#`JnFOYm%$mCn_Q0AdUH5qYOc}RxPg3&1fKK#a zh)9S5F+L=nN5E%7x-vq_^A;o+Fr?UVJw3?-*o|d2FdU=6)Az_-p+S{;y22WxxZ=RNW*RYSXoyBUcd_&&f0+@#)4HY!fIye z-MF-|(+ZP8cY#%VjLh0lqc0(5yUw@pdUGgMa(0JLGx@VB}w*g38yd)ip&W(oP@2HG~ zbwi%BViPbuyK|l?c^J82sSyKux(=EMMnZYX3Qp^1v?59Chl({7MGa);ptHw&kyGU0 zK8Pbcq_6~|NTGrj+!1RIHahIG)OF~Ng4uu*q+EPzu@VTcP#t&|lL+(U%3(8#jT*%& zeZ7ojNI&0$di;@m%o@zl4k2P-Ar`s>!vt4)8i;zeYKLLm0NLOaNpGVNlC~l1T+uo6 zyTeqnBGTazGS0>OJfY65oOWBzQ!*4Cw?Qflf?}UAYKOVC2MK#aOsuoUa;L5Cw>B^& zS$)eJX|7eSw6ArV6{>xgs)$;VPaD*sJtIS{>1xZ+E2ikA|Qc96PN z8!G7#57};X>Q;}N9`ztuH^g~`vTNx&DX9H3hGw-@gJ(fvh3(tjwA1z>ZrP(GHX6%d zu)-k56aXFNiRz651abkU^xFMUKLneR?=FG1q8)}nPAEnTVs$1_`yzUZ4+ZJf+C80& zciO?l&$R^^Ix>u_SFqYD;T7hH-(ZH-eI z(JOgrN_y(hfzmO8cC%INty+bd3Puuz1%&3*Q>xIbo}dlqMFK0#`Tu_c0p5+5(@$(e zq5#7HBLZ9awI~xCHWX$%H5(MPwvx$TNvN?XEfm-)^VBTJjugyT3f2^@G8b1si=#gE zJW-x?wQ>6uR!<$GKu8T3p8SWLthRTw@lTmldmRze5-L<}%A}2=`WWbHi1EdY& zF=!M$sX??I+_ED)B*)4K^=yU?)azMnHB>P0PI?x=$l74#>eil)mCswI90zQX6MD)R z&3Q0w11aIVwtdySW)Z~;5FL@!o))t$N2jV$^4P(uWD2I9eD6>7K#+m4zyc+8#v(wW zKrxra(1BJPHdys>FogZKfrsi_?@0S#NVBb`#nT}U4$Tv#K|x9s>7&@Xo;8FhpHOL^ zOJ_a_3XXYF3~gK~7%gW8wA#huo)H^Jmf5Qb z`jF?i`;mS;4KnrOV0TPA$(yv1V5shtQ|%<|`_>8lshfyoUHlCtc%57*+H~=#g=@2P zrij4uE#wpQ%vX6Hs)$3UO(el?gJa|UARKBAcgfE~Ot9QhNuCT`TEd}s;6pE^q284l z4Hk_NpztxTCrepZ9bQ3;LKo`R=m~-tY%0m(E3c4o{dU%{HpuQnIUEK_rL8#i6PFqy z61)Hz28g7{$?6WTq>gcLt)@#c*bc@C5-4ReXzw^S>7$AuMff4$*Xj=8zu^hZzY;#8 z0=Mh1P)+yrmmb+l+7nh4(>|QNylzMW9aa=-StXRblkV2ov^A-<*clltJ)mwV&G%4pec! 
zbS~1&qGQcE7KFYohy|=G@TD@$C;bv?6(Z5xvL4RD0r@5KLiyBXZXcuJX$ve&kZNW6 zYxW5E(04l%BA4TcxDpC>(R2J1J8+HE9m&E(p+1(xjTz9{UM(dGMUOKP3_$8-jt&Ev*HlwQs2n7gfukA$K*YS>2<|&t3-&f|ZQJ zwZneSpggs+wri)}Mm$D8S zc^6wINM~XHw7&*@sOjmZ+IL7ztWCZRZ&)IA5QJ;X_bv+)t~1(t!Z)i^do-a`; zg;6T(hN%E0dNQ-s7YPmk1R%*@5rki8?M&sWR-%0Pz6&JBV~aL>ajpA(hpsRy_+@8 zYluf}U9aPz+BDi}NJgl0QCZe9A6W|F>PQOEAZ}nGSFR>H)lLm#@X(k`>MS-8e)EHG zb=YmkqOCPO;}vdpxckh%*h{TS$@8I|8-oqUylD5ZaV1t6n%X+-b#&Y#=+t3LwikSs zUpfGQf1qSzp~vOSnw|%lO|mxv?X3(A^bl^ zs_V0s@!F?4sdjB0AOKKi`fV7b7h{)-89E_z)Wy`X?0j6+k#0mvBYf{NCK+3EOL!0M z)2GV9BRH;SPbxb3;@)SU)Dde()3y?g&ycj#aEPxd8a19Uk2+pViPo7uFi*NhJ=E;; z?cz{MP%HSJ&5V;CThNT#cQZEMPkAQtUXSAHR9n)45S>#e@z8sFA)9~#(oNlR)}`qI zZ#}C~b&N5UsC_-sM!$uEKWo)z`#XC7!yf&co`0!%tD{RVHE+Ih%J_j})m&24%;Y_3*r?lI`Fq8^}8e}TUd&g=S9fc!icZ&ZSceFK8|n`He2-0yfYzr4)+b2@*FlKEwr z=9hTEz5&zxnizf`COvrZ?J}dAza!Q6AbpEb3+M~NpLPC#?^K@L)bEU_%Su>un?L-{ z2&+?qGpG%7s-9tfCi+$&);2@GvEu&zB>jG{nqR*^$$LHc{QCV#3i$o~NfZ=RB{p`; z{iCDM6;v&j{21!6HU9$uv@Vs=D|l7_00D$)LqkwWLqi~Na&Km7Y-Iodc$|HaJ4nM& z6o&t%ic%^L78G&FP@OD@FC4WBMW_&Jg;pI*Zu){I4M~cNqu^R_@Ud8RaB0TpPFT|f9A{GP3qpBVR& z!f~MU#c@7{fv#PkQFEN{W5;Qn0KsSAN^kiqbztU`^jb@c9s#}Ez{Pb-Q}%$%9bn+e zkWJZ@f>c6254@kzH)VmoTOhRN&8>Nk(+40;vr66o2Zz8&fwI?q-re0kw}0<6=l26# zByyg*1Lo}j000SaNLh0L03p=?03p=@1e~;;00007bV*G`2jdC?3K|J`+td#L00vY^ zL_t(Y$E{a=a8y?vJ-_>2vYXvaNU|H6lCL(233LNBq0mr^85odG8m$V>FfGK2;#ga> zU!&71O4??ml_^rI3}dw|Z5=BuY5`$HZ7rfr_zEO8k=9U32$-b1&Btc5`}W=AAG-?? 
zmi*D~oj3Dl?!5cX``vfW`JF2z3eQ2$EReQJh(y{mquf=LjNBofmA zLO@7JU1V(CK$|5Qo2J(WEr$U~mO@DbQXB*d) z_x9w^**dL=P!Mafn$l4jm1saB>qh#UhC2V1eeK5lYMUlSi_QQgo!dwdf<@txHG}N| zhx>7FU7zVB1QabA7k^n=2qGwsg}*=8=7~*g^UeEPmRG0%effiF05fde66mZ5o!#x7 zz0Ff95~DDs&8m;V_JrwtXXL#lfv%41@(sB&Ap%NbcP|^}M};B<0%-~jtO~RbIdWTl zHv}>=5)ec>E2D@|s-ocs2M;-|$j^N94>@v##?-6SN2Va4ZyoQQAb(vj|Jj^E3Mlr} zi_9y!0_o3J1s(!8IlwY6oQ4qU#Cd#R#Q>w$D;*`X{xMPAKV~!(h(P*+>g1?`*r1V; zjph|}Dxlu}YeO2C*Tldt=ENWa_ zC&13V?O%H25NHrOWyexia-i1c(m7QnHZ{n)*`==^cpHRJFS)sb(64M>T3qP;!@k4A zBjKeDwID&;KdYFw02-3IH>L(#alQNQFMnnv6dnsrK;W+~d+D{F@mLJN>{+GH+;?5W z-0I^!r=S1rvHQQ#NNo=VhY3|yTv%T+Rg6)W*tYGz%I%+@Ne~p78caZdfc?V8 z#pN@6Z?wJr&~vAcoj5}Xz=bznbMWwg2q>{5vDjmq_pWQ6=dYY`tmpKaO-J8+tCtWV zZ>%i?2#_+RT@fNsTfSDcsNNr*u)2GCH?O@mCtHya!0)f<=sHUPG_+=ZuGzG-HfmW% zI(jy)sq}db8w4=3qD1K^V5Iy(5*Ul**KdU<1H$I{IX{1VrPE;mN&z+?@kInh2Zux6 z95(?#!jVXp!^tEOB$;+7XxY5$ksW6zn>_%GMgwhk#{&XIq%_NV?0Zd)%#7A8zk9Ro z9Fv4<{^E73mp2HJH_g-4+wU)*0nq7mY~T6!7hXBRgf+=cbEj{3;0vOqVL^rL=woV8 z698yiFa5Z&$d~)-zV;2j>8D&Iuj@~zRxfXW$na?Rb6|fqHXF**>=evDt9{ndJ9{g4OD)u=E1xik~ z>)WelJ2MRcSq|rYD=W2VilP9*fvD>2J_8YNUiR8s3R5aT3H4WQ@!hdv;m}Cro`-jZ zv;&B@BMqTHH@4xy&k_V^koT-!W+)S|4PwbnOgxhKm%B|xSA=Or!rLQ zL$^_(b<^(;_Gt*gL_9WV|LTdxf;{)jyAK9OLj<9WKfZBnJYs9nbFaUqFjGSS8Wfx4 zbFaMZ7}sT(TGvQl}YfcyW_(f8l+pZ+;WOpg0%nLn%No&F=o&wO*s(I4IsC@apj zZ9clK>#h}jh0~`IQd90a-#u}GN~eOdedj@^DFWrM^qlStUrPPpSR`-3)9_@08W2ec zF6qZ4ArvLB6Gs9eC=?=oOHhe`2o;oOT$q;}Q{vQ5WJtzOGRYJT5ol5~ e3QT}(m-#{_00009a7bBm001G? 
z001G?0R)`1od5s;2XskIMF-*u4htO;dKdTJ001BWNkl@T@03X$xsL)%}gWR>eH5fsP8#pt0;;CPx^_9qU}Wav;d{?sn9CVb_5wL z1w<4m1Y|BLps0b$%IRj_A#OyRv-eu7|JWz)y;-6pl35h5?)k+pet9GBjW}`ES<|<^ zwHCQ4;0_^ztC0ZC4iEtj08Rj4Fb4w=>>X$GFqi`X6Ce)2;EZzd*)r514Bq~IaVK)P z!O2;&TO>0WfdLR04h9$`AUHt6W^6$D5w!c$nE(O|uU=V?LB{< z&j7^W1al~NFGcI5PK@LbfB_)@&}q~I z?fvNBbNlY(p>uAOx5H9^$%{5pl_vYl4{S95ZspL`r|z7~)2NM#D6w&%{_vyI!O_WN zYM5sb!5yKzkpWJ~5Tj_OaCHD~AOqmV{5pf&(M833$Jsa(PEH2SfRllooEYTHqN!!V zc{F~&@eloP`>R7oKizNF8Ky4ey@tcxicc*Z`qL2|IdI8V9hU?OfDnqg7r+gnitY?e zBn-FH#aE&--W9m*9cSTin&lI;oo^Tbo!T&jI9)%ub=T8Y!j)@x?n_%iGl(Ez>a@zBx6XT#ILikK%mHR*RRuDL zs#y1@K{;+i9=#u3qwWh3Q zd4AzG9)74={D-`7XT2}gLrQBt%S)*Ubq6lvIo6>HLWcjvU;Iu^FnV> z)J_2chi!q9gJH#${u8@DI|V1&6UEtu2`v*53=|>@x09B#uZSMnxc#XI;{~_gcCl`$ z=jEwMWKxocDl})Nf3f;c6vXF5WRX0dP2xL(G|W0(mdlH@L*W#TGm5 zMZae+N-CHv15HzLuT24J{eGy9{o@nzk!z1{9P%wnjkwgY(xj&X-@bhQUrzb<;lATC zOwv>LiitGuPLqPc9JCD|%_Pdcqo6bM>!<(L-f@d3J~B)$ zSwFg99!1&=HH;~&&F_0(fAM|NTVL3}C6x#vtHFqsQ%&ylb~}HXg~H@qkgdJrUyOaYi!o2PXWdU^GhfyfU);b| zFE=@u&bRBArnM)|`-8{FSK4VGLIxvl)JOe=n}-K)+@f3J!fF^!DM(RBY(*;oS%USi!@NtKPfB+GlpKPypko;gb^uM}HI4v;x304R615&*HKB%$h0 z)~=Wyf0(UZnT{__j`fJMXJ%4>;=-|N<&T>Fdsi-8?=6F3=Z%@*oLa4%98l`2-C?GQ zyY4GL{r++WtRN4*=S@EgJcNHHmSpeDUMXc8oD2Ah+=d|ioX&5yt~`FnqgHTHyT0Hp zGQdR46Y1S03m>lbe_WQ24E8BkpAjc;+WE||CyBGFs^$zh*bL-G5~< z2BTTs)UnIRamgmd8OZKwK8>%Qu0C~9Eb-1#yVAFj&u~_rVLR+(U;|*z=1^k+0&Hn?^r^gY?b?YY+c1Hf$5?Alo1y;X z@UH!YUApggdU(@HbCWOU-`i}8xpyOyTozKA$XVKJ`PgIkv`c*Du&j={c-{v3qi zfEtU?8I;KOgnE|cn365^B-S4@2n&Wrxqc&073voJ5A=Dd~ z|LEMIH>~5t@`Y*j)=`h!y9ISjK9=t^sXfSV@3M9y@&tu^ZDqfjEEoK#2z<_)vvF)Bv&%~mhu8gym z&K=;WboV+4MiKy+<*ui2IvIk_Z`ZEF=0is7`^IbYq}pV4c%cp#s8MTzJ>gMQ<7L$+?3`DnLtzshM$Udm){LNZV~M*aP&> z%E4ZO6rh+wvKbn%%a<0nN6e-wQhAeV4vW*xi;ms)w0*rp>qpKb?F;3s5~>mC?l5=j z;J{m_;gj>{txM09NZ8FlQX0s41|D^=gWsc9Voy{5wNVkloXiox27|J}05PK(vC;q! 
zUF#UZEr+Vd8KkK<8K1wg`i+Zn&8a&Ny5_7+&Q;x;YP;57{@BW;e{w3{yu3IW>@yYi z+@+z)06Vdwkms2TQw(RKrF%j>Lxc!`GYGxjILRjw&KRJK?TSj)*U}-!YW;pi(f&f71q0o~k zgEEl2Qz&A^9!z*v(izNfd9zuD%G-I4aLjEcAs_=vP&b3e#b$cV+U;MnfQycs4!q&-cWs2_kSOqje8=@4Mu$x8&A^^N$Yp zjmxwLXSk3B-ofijuc_@g&)wY;dzyOI5GCL4q?kbgLhdF^rkM&eo999^y=Zju>d}!0 z^ZLcjW|cD$sVba9h^w;jj{e2(@Nh)tB=iMI)c_&qU<_F1SxoIbo2$f5h3Ev@X`a{< z>RIrr+zm@8V0x*&5CPfTnUQlQpeM4pzJArIJD<8NSFVre$0Jp(dNJqJ)b%my-?i`J zH;?>|rSsZ)p;d=4fwrmMM#VQ9IoJixM%^T|Z4lr0H1+JLDLLE-OkNzT1!X5>&Y*-2 zALIxhOxGUdo6lT{R~`MtTss;FCsU+c^}=oL|FC@FUv1Hmg_Ti^jkY27b5dqFWhYqa zsVYgZ2^j-{9Y!FSTL9UacQNhx)X(ZvY3BKux`Ii$Xc?r;Ff%aEX?xM??T?nxrE7;T z(yjSG({@6xE&POslS>!8J=Gr_?mHIhhM1ycZ9zoByJ3#<4<{%?>GZDA&i>n;rhfhs zRyD?7FD0_9sEc*c#ACDRV_Ak zkRaM*4wbQ}S-~)TSCPPr9Y(v$;PPzJr=JPFzbDi)RM*bi`SyHOB0wzJJYivS>WYoS zkE!j#jiZB{svtRg;~)-4dHAWN%igx`N0%2T{h?AqN+Hx)%ln#xX}gueJ7fd@3~TW8 zo}4|QejfO*QwDiSZpdJqQ6f@eNHw>${;<*Ee?MmzoLuj>xfeoClR(}Qty6t?{=lDY z(nl6A*sKQb4h8~LK_u=Bgzf39yKTdcXFToy*w3sCwh3}~PN9%gZ3XVHT zW|{4Z{I4j0?+Nv+SYMlQj%LFg+}KUpd@tUktl5~7n-`4r~ZpUFW~pc`rX8J zA&oCwz5VI?sw*~*9Ga}n6D((9P8`<5@c7aNf3fKwnLD^z&wGg38QNtkMcv7}%?9V_ z`tND#XAOO4QD_G{Sa3TqU3-EZzixAMAdSS0k|L|rab$d~KKR}|_le~Lcg`;~9AE}d zs0T>hq`8s}k)KIR_@1Vo!%M`SP|14MTrfIy<@(V_^7PR9(VjK{+?}gxv(>f1@-0gj z{GHXG9xP5{?G!+aD9BSnSB2T0y@P9Y$%q|HdZ5|7=GdpcVKJUR-5S8% zJu`chDR|slx-~5Q!Djn`edn)*SW+M|qMQjz$bgl>00!|Hl)q%K!r`b!XkZi~%3OsX@!w%ufYh6vG~ zP|vPQ1v3%3W?Inp^6~NOXyc0U>V-C9%9@PT2cb^R8*}^K#B+bQa^OUN5mdpL+`7^V z8vv+mBk}|#r)}Elo>1>~xe_6_bC3hg#)CKSzHWgIY@O`4lLnci2w}{8G%o$)@ZkF} zfBW42hH8LkgF6c}e-FtObk}$Oy(iRr*`;Rlp4md)x@dIltLOO2<99CRM2JXlQ}?c1U?t@JNE*m&tiUS4#Vl_ z^7T(YeMK(WxciWfDpev%ZRY;)cXm8BQ&ZNm zxx@{}*1~xo=^uLcgjeU5x2hiJ0E5d8hoP`KX8_@LAl$WEox3GJ;r=F3+LiL}-~eQB z3xO>=Gt?X+NbN$q{;>61zh)(_tgVYFp@OTrwc>boU++VCoaGeTu82p%FY zjMuN+IP%!4xqNzJv0aB#zh95XV}$52-W}#{M}PU_i|60ugSA7K zjVNk1R5jcnu9of#U-14I{gShHO$`R1%sn^FkVw{4d3w;>3)hZ6WnM2iamSLTsL(v= z_p03Jv0nA=eV4y(5{}F*kIWG&CNZV%-9pEE*%Ru$Y=LKi%q-mPCjK#yQw3hu$z|(@ 
zpRyb;ICkt{8lwS=Wb!dE87v&0JMhL+`NNC**Q;KFQ&|^cv-$Y55294M;6bY4EHrK7PhQzP82{y zh?WiafyHU}Yw%&G_~jt@9P-kMBaaA^2RFwq*gO_Z0XDdjY!a`{?|Xmm!23*&R&yh8 zs3fc$=*i={Cv0baArz@XgzwFl-M^PR86wtW$28 zycOQ+FWlB!eskNudEwx+?t?rlAOR|XfLz@OXWACA7yX_CL}|OVp-XS9wswEJ{=oLs zQft9A&inHdlTY@RPW0zDEr%EZg0mBEk2NcJ?|osx*8MV2 zn(c+1l{|USj{eQst^a<(E;(`JVu4S!1qVhpZzUe?t-NNl`pCssovH^UF0oE+3+@9b z+u^Awy(mi1^o0(W?uR~=;Q(*LutrD&-Fjf3_Uj}vZBnKEs17D;rCO8TM|%fvpWAtu3m1<^=O` zJv=_Q|He|hI;Z1sD+ z@i>E|3iaC9j!&BP0~f7H&qV--1T@Cd5sgl3HN4NPYQO*a)Dj<1X17cZ4qoXr3uA~QHua5itbVwi!P z+}x}ZHZyfX86WEArFa9tX~ren-a(7)$-uOYXOZptu3wZwbqc}T1m+MqHw{Z=W=z4M zP6?VB;%Z>x$d*#z;OZiETY5lY!i1axA!%Y}cOoKzsHzC%Y(eUjQm8`fiojq&4&9d*;05er8>SkjPQ(Z=h zWoD1rz6k~xLC{R=w6Xvh-SP!uIp1=Lch1o9ku8T||f|S6qF#j-NTy@<~Pi)BEfHcDZ76%-e z84lzGu9``+z+0bq*NgfL2Gk*hlo12k1PG>t%OCMVi7Rl8y8fzPdG1#|Von?tK{Cqx zR$`<=XmD%+U-=(@_0uPP>S}B>1D0no6n4lfyVUfHQ>YF|K+eo;nn4JmGBZz;KCLYZ zD#bu%)2F!1=v0P7f*TBY;H3+>aaAeZ~6sSs96RpSXdm`Y_Mf`(4&5AHD8#pV44@> zbU99Xj$+Ui47j_2!7(VH*46TYI3dqm=1v*DYlPb7E6tYqebHjS&X!CdZWN?)GYMcG zS7H1OU-O^}aAexp*C&Y!`w7OUCQmEUVv6Bu#SV(sV{is zwGSp80gl-=Eywn@cVlf9s`iKf%Qx&BppIaU$s7pkrhxZw-BX^wiKRy7o7O66Lq2-b z+kP#U$}56{oe9DQ&t+5&cj4@JJ^Nc=6Tr3!ZH4c2k6jh_m<0O8N#^Fj6adccB=n>C z7Poi))(tO=4pV1nLTW8Rpaw8Pg{-CwBd9hP?Z?o-Nriig`VBXJ+z4(kbKzL#%`>P> z-HgQA;&*=I**$F5*0O<+$-8cFAc&*RTj%xj4}IHLEC&!^;EHyu|@s&TfLM_FR z)>#bcMYNQfk8w7&Z9nMa#xKn_~~Ds7&0J2 zrp_+J(Nsm?3I+`-R6s>r;yT+cP_)6=+vuYK9sj!ZQ;jW{IZsp z5h}NFjgvqBli#=0Lq9+r0Adu{rCxymgDIH1HRF|!#k-YJv(+@S2kp+1&+TPD*P3e1 zEvoF4G*5<6dwK1{Z-3=H5V~>5l^_BLY_qL{IyzyeSi9{&~BMt8Un37pLoCC9N@+`CJz zevh5jUo=83Bdb}Ny`okoM?d(E8wLQ`%ySy7EWQGzY#^9wnb{Wz3`F3%FYj|&XLwfD zxb%+qZ8Q!d-@PS<;SO~u0hQy;Z+MC6wCeR>24*->9u08v-G6<3ABf24TrP-_GEIRm zd(;n%+7-uqB&mICNNewY^Go}N2uSIa0th6;PF z(;8aG|Nc$Cl=MJd9YW@iH*3A)L+|{R%1~3eEj7a8_9R|o7z6AecSfjTERHI9Pch_F zu|acA$=%(OyE%6666cPZ`b8^L1_x&D)2OoJFL}{7R*D`WCne_WIXN336jhB8w%bk$ z)tw-?i2@0t@T@nOYAz!iU}qzQ!;XD;oh&1r>ilf5E*IG$NH6r!ZV)m zfIbjmo~=Z!$vHD_z5_RX^zKCS(>B8y*Q_Xlry^T#tnLZt|VXez~&LZuEp5M{+1|M=H> 
zzD9Xk@pOPYZ~U_t_6g!j8G$P&w7}G{mht3g{m5Ea8N1d4uj-BY-slO}T{r;5fq=5c z%^}&46e$-#OPpfbHDARTt$HT>)s5E&+GyJi2>}kDwjz}#CFg&dzR(xrAP<5CpVYqb z4{v{E&(ITqpjyO6Q$<5~#8>`j-EKVDUJ6{=TLb`N&{{Cx2k98hO`;!q{)_W0tR#mp zwQUQqFqMd;5=v}W^pt7M>z+6UDIglCu_R-JD zaldgNk=wKy?btis{7drzku0kL0Fx0Vz!_ii@TY5qM95tO2Qff|l$979#ew-bRJ;`^ zyDK=L?!gc_l5;0l@^cLno?D>`nu=C`{D;12DWDG{xWZ8(mvCS3Z9nj;Raf(Udx7p5P6`wjaT1-J}>XIpJCF?`*({6wk`On_#F zB}8G%Ex}^!@DMUogn)4nv($hYNzLFsTR?r!TK#=PsHcVHJHRG7=id#gD3>k;j8^LP zZ++%N1Ih+EC%7UT&_Yg5@R#qnc|uDs`|Y=lt%Q=hhSbBDAuJB67(e^tUmbmuw9PqK zwfy^k@+k>0RS{CC1Hz)V(=gZ%`13#b@BiN`e|DKMpjiczkPSu{C;Z||ZyftV%kva_ z?Cmg4Uh|tTUW$kcA$9L&2%2#QKD~Ry zaA33z#Zo`t=_-fOIaSloS*Y8cMP(jPhr>BjXU`nu@JwzIR`W#K*u6A!Qd&DTuJ1__Tq0Cg_1e0hQh&?C%A%jW8`DTAMQ$*)ZvERw-a zp`2UARObB|gB(|1vhY=3df7n11PIhY1QBH5WX4N>@2|F~Ru&_MPYw;X9`$b!N&%N> zdmmQ@HWMEIjOVTU{&84-#j8FHK#ai-_5$h}jK~V;&?7wQt1c0w03(KYc1;R3HZKrG zSq$g*t>IXF56q~%NJL%67Gq!%8^^r*(Z9c8fgsjUfoBx}rWs23iYNW($X8m6t%nwa z^4Bv8#P)RY?gs**Q7wK*(j(-|oGI+w3DU&ut_G0vtIdW>a9SIMw1S zp8TVmc^+Ztc~W8A_sw_T_{(Tfoqj!9!$!iFKIXeO(wtSx8kc|d<-gUmU_R{(i^bxC z8d|CV_GxZLk&xS_>enRPsezvk|K1nC2L~H94>M$j)e-;WRXywG z?1RlfJ|^E3eAD;5;$*Ww*NX|!S`ivhgzoTey1Y%gDA@v&8J_i@zNTLzYYP@LGu7=A z>t?poIi&M{u9|AIJ8qXNEyGjYQ3Q;)Yv(HmLmkfGD8bDVaH_?R{nX1F zAGRLcIVlo`&-&J1$-4V=c$e>XKm~bmfI*dd!77rR=*7SE^0tsY&1bt};2ktzQr4Sx zL?Q;rs8YhCpYj6lU6f=lDG_Xl@oRqLzx5UUU8B#5f`T*X_TzZx2k)@{0ka5aP)(>t zlQNMmY)?d=mDm87n^lY-{(;AeZix`UYEDc$mDC+s+BpiB{uSo8WD@~i0S3=yt`~Cv zGY4WH4eapSUi;z&L4-1#Km?d$>O5v#`&G}Igo|71Lj+QJ7)-qP(OVC@Kux{#&kTgw zts?-DrwXXUgl(j`A+AywPD*4|hG);N5p`lTg*z22o}FX@yzM=!M^CcyO6D5ub-Qud zGC$@^VxNGNQHt+CK{PfMPkj3GT3*T&3+5mcH1jt+^Z9L_C8xd{{G0871NwrBMnU-W zF?G9uq!KiCx5F1a`kh7jyK43Zz;b4Ghwy8~*ru%3D>m)-cJv-+1*0+QMydniafwR!JzP8yo;a z;}|=>^?QDz<%K2_hyo@u)^0nCkbP=r^!2D1%zen41I(b?rW>aO};m z`^8Gp7nlMR;VjGwY-HSg7&m{KC*BKH?JWn?5J6SLnB&|BKDw2<`#;UbSMNk3M71TH zzaIgOS+WpvSib({zuOQJn=yMg1F$O;2$cXD3bc+b#nlgg-k2}RREJ0b zWmo~^i4|sm2`0dI{J?K)`f!wj!l4;%UGZafmOE#w)V%b{-qj-2KRd>ayBCN{Tn$w>qi8(I(q&Bcq;FTV;AU?D>bsIa)BHB5m= 
zJ?_W0^gN>>zyU%TKm;ndJ3^}qKlGzNm$4%hccC2#m9>010LYtjJ%Xw4OjbWVJ*E4A zP@Q0Ki2?3~a7z*Z<2)Yd@mt<~!-AkfNBJz6oMt1&4?X|4OgW;7K8kiKrV~9`=X7*H z>peX=;I+y&*}H-kuzS3Dv0P?=Feu}i2SxErWaO<-`juaPiOsB=C8biWG&Y071A%}# z;CFxLx5+b0jy&$+)Ze`ECH>iE)Rdbp#UQxj+An`5Ge=6%M}6CJN^T_t8UHCFat1T1;-F7M?~(s5@ktiFZ=ZuS+)?EU}3wzg_*jHCR<7(^#d!V$X!2U9r|!DCgMdex?#}g)ZQl2BNsApCa1C?h|Eii2-UnX!yOA1 zv~&!Z!-)m%?iLx>T#dy#pG;%pF!M%yFJSMVqfiMz0vk(Z$zaN&lGdK~xGy;{hkig# z8Y&?m1{9_VK)s6I_10In06Ve+ojpYa+`xpXK^UdlHEXixUD0)cA-lwoT0H;zzw_s= z|IOTv9{m*$eDuRE8o~$}=ESiyYM4RYLSVpS6jSsVb2N6(fAMEGEDNG{OnT2CnK4p< zpLpSOe(a~7QwpilQ8RW>s_h>m5EX$4P=lkGUt)4EDLVzmAEUA+)_>oM02Za7Q+x2C>DSoFaU-?4Gg(} zy66+4BkX2-_)c_I>?i?hTG#l_XFZ^w51%);_Pf9Soai9WumlENi$qBf0z1@H;m(Y9 zivRqKM?L65x?&lFf>IO)6?I`kiDQFc0Q3xVilJj@mC*N^0ycRIUv6Vm*?5h#HeZ6Ze%DCG48-uS>fek}sT9R)pG z0*shLsLeqf%yVQDb0#oQ1F$y0#IqogH35i#<>-B&5%X1aPr?XU>VB2X492$`#Obn4Ajz9>o z3KoFis9+Ux52->Tn7cb@w$HYU10=u!k%|CzJokH^%$#@83On7Pb6I-(9L=vYIXMxF zc)K=dN8bI{zg#f_L%obXK#))hV3r{P86XN}plU&)$h_vt01!K}u%v7uFk-~^AdK0v zEH70v0@Os<(E;=yc;}x6Mi+E?P*V|c@0uU8<|x74i9$(fj4-H*(^E(_t5+kr3m3&T zDyCnPgxZ0elEe~(qpRG^(m5i7%^U>5U zaMwMDckwop!4-kT+tFN@eD}9J?a&+s400=9J-f9|384jhL?$TE0-4}}RN#b6RQS*d zn0jatAe;duw9YM9@A2tfAv1xDe7(0XtW#P!tC0v=-q6%b?_4Jl`?KGQzhu0dD{q)G*WiOGzRtB6N@+0}vx zQ=(n=*3V@@HNZijZmZcY_979vv%3-#OyR_AW(E*h9{ScGpZbySd176-JO#`IoVZA< zfQdYdiJESS?9OkVCa@XyUWnIKL3;t6;iF7txQ!7|LR=>JY{P&Xamg7ROsuLbL{z+` zN|$SZICv*Y1971#wFSoq%)yS<;e^Z`>O=R$M|U{8W(WcSTv3(J@Y3J~A}1^MNty8f zzT(>+`pD;#tiS^}X3by%gN)qH$%M49jS<{~X5mI4Ftg=Ws-CI>Oi&T?>H09Pf9PG; z4*+Na3(Ooa;S5ZG$36MEjrF6DY9b=*240-eb07{jdfRJXw%G4j0>r^)^*F(4J4DND zDjZ5>QuYCax|asM$FXDkXSu z01h!Gp6#oX2z6HrA?BPxFm)Do+nTG#?|J+6k|6;sF{_J|76fb=zWlM@u|D0Oc}Wc( zwu{MphR^Okgf@&;5m|}YP&R=l0pJlk69X79b4!E`#p56UpWgM>KZpc3W9nSytZJUA zn;8J6(7J-K<+%KsZ&xQ>+m5pzu6k1L+n6ZaQvf)?}ByBl_ zjLcwRhPhM~2xjOD#K%G|EKHm=qXjs^T~*JKrk2=R*thp-6s56}!sm9%RbBjgK-1&R z7yVz)9tZ|72B+i_DFY0&hS&YY?VIgNOUo?}O=_plm0@KNh}kzi`zQbOcYd-$l~cnZ zU>k+oRqA3uK}1yoZo0X>dtA2y0eZ|Z?v 
zM?d3*$r4w@kxWZ_fittYpNo`+ORqIKtXnHd0MMi;Y}!&CLtsy>c;a+wzpQ@CGp-x} z>`aCz9gTVdoNVwTKk^e<5)eYofH(w1MUtmA zfO-pz6wmt37ijO2=l}HYPG>WGcijPO#YdNA@nj%uHTdrD`?*OI)QNgENMHgZparII zbtkvz7GXpsd2w|p+{?pRlmo?uyN;|yf+w?_39y_?fEHsY{(=>k97GIW_r$C<%`j&Q z#Z8=}Je6kUv)MjhXE;k$rRf$JPTp=!s>uB??dz$(f6L1UhN@7509T}hWZ2BO?n%#` zU_N2cv`WHgr>Nv}F78KhNRp*g>}FCpV76Vrv|WjT8aIxIKl-~GTUwv?{`AJ1+5#&2 z&i+I0p1R5EIcFz4?6E(P<$!Z-@Ge^M64Z6|J_qlb^1}2_=^v&BX)+MzHm1l_N$9(g zgaG$#<4-^(77WOnZ8Ksa34xfco$e5w&yA*bh4^j8IhAJiP_b}eLag0d9v}YbyI;B> zs0aj0DS^NNF`@;&?B9L=DPP`F7#YQ$DaBrmG%dV{GkgF*>lM{m>)Q^)fDpGwJd|vH z*FG)xO!&6%`t?yh&uA|6<~DM*VQ3Q2;dPx>^R6#zA>4fX)`|x8qSq!Jbd@6cC3_b zK$v%t4GJ2EH~r-eMr&rtEkYMhohVI?;dIvdPV=w0_ke&A*{Bwfu{+HoDv$;CKo&ZU z#O`UY9jiu+f+2c9a?Y7OChuoJ2s^WVXcpBv0AfJ}P&ye2ASvS^G-u{;4tJ8V1cRcsrl_SDcN5ozdo0qGr_@ zuX;xDGMq*tAz=R^2%9mwTkb^+0_XG@@BbMc*qJh3_eU=cn5L9CR>*CsY#LC;5_C7v z>{E7Y>b9R&^$Hn>s|Fr)$pKH}Kx1xid;42{o&{zNb@9%2y^P_Zw2w!@ zC5I3__j&_WLm*h{D9xOWa~Cio?3`E^Z<(_xdm^H`&Pr2FE_J5)-}~bm6F|cLPa;DF z8xVlkyy~Y0Y+XuUfaY^EaDQja5CfnL(gIPEK+yR3@O$3&-xdUO6-w(8gM>nMngCz_ zfBrvPSovS=U4N);X<2@r_xsk~`~0{w_d2GDN>GG?NKIv08c`TXWrIQ_h=CNSCF+0? 
zS`rRoXpA9RrbIuWoQx)nOp1i&4=bjzv45fvwZBYn<=lJj-RGRW*ZSV~>5uPQYwvUJ zb((XXlOxRLusHYJ%{u$6@7Mc2@AEvbYuEA|vQ1G*6&7`lGPW~(K|sH2ho89r{Yj88 z_=a^f`4ZJ}GGGVzvp;%g_T@}hzV~@|+;IW(1WPtmLIQST>lQ-T%b4fm!=nQ*gM^Li z(gbeNB~QcP0Y(sEa$h3@rT_|Ng^Q!Wk#CRa=Gz2hV4l!nm17111)?Sv(K zC*v7+>~&aIFD5f%;(747ZQ6yK9H4~*!wwH^{gsb?49IMJ2%i8U?`Gpzo`jPMc z!ax6RCu9{vqIJ{iw2|T6!mEe`bij|k`fJT^Y+K48w?NP2X9be$( z&wl#eU4bz;(RFhKOo4=r1h5M1D!!y-j`!X3^HDHQFv}3!QOUO4#JwZ-3qy$N7ipf# z5pB(!kUDuSvZFN;XBE9TENZO4$gItakwu^*`V~|CFDq;7?9#AZz1*61287UcxR?>c za+Y@c%Fj4Ais+^yPt~JnJ;U{f9{At|LkgtaTp(#S0LO-x{N&q8_q5}(tJKMuXWa&T z(bnm;MT%7O;~gV@O@M0qp74ZWuIrz;@7{|E3qwkzB?l_)o!jx!ND67(rGz>9o!MXv zI21@}W1qB%mZoEyT%?YU_rB}xi@c^64Sq^GaiZv1Bl$*N85e8s$XQN^U7ES4$QQr( zrI99sz*b5Fu+|t%=(vQO%;#t4VZtDK^&qAfE7$h$)mOdzg?H?tQ*%4QNabAtpZJH> z7r(l@Dl-}-7h^yP)6@|Jr@fUkkeiHkGAC!Nsx`#OeMVoz^108x{khM=ZUO@lCcsg= z8W_QRXkm_;L?s|^gV#X(fVHfwhX#aK<3dC|EGLFAVGF}gz3kE~u5=P&YR`x%AQLGV zA|tvJ!JIj&!&(SrfP=Zux+Cnr{HK3uK&AATn7Gf8D@d7$nc*Y9bx%}&IfBCK<+lEp{!(rfB!Ry{~¬<+iO&KxTH|IAvGM0Q8#R(*hq0y z^AUp!30;?FI$rpHeezGmL74q+*08JH?Wh=x>52c}sJ0~9cnKo4G* zi8F9C!WxrC0a28R8Ia;CeMmOqe+{j)js56YNfl!Z!cWk~s>hWB17LEISlND9+_fwag1ZrU8Tx2RmH;+dqAHpCGU<-jCfrhJwLx=ZoK{ zw4b|qDTO3T0g|P3x+thoFsjMKZ4fmE+hGbafn61$mi2Px*FN{q2QFmfx^o7ir~?J8 z4d35HB80Z@G8v=yHeZjWSZfs)2e@ zV+7OE3M2H&mHTgqSnniei;XT-<;K@Q)AtMlM(OW`;N?>^G1S1{ZO3_Z#kd>JE~l z5p6LPt$N@}!G}NknXex1*!<~^S=!l&%IK^*lKEUQg&|v_ZUzxlO)fjA!_+&=TxHp1 z@o2fHqvt+b-~J5Dt!=s4Oez5^flvR-wg0$`WjyVH2fyHi33MPFhFYf$MaorA4x$Xe z(BV0EzV1lR`k#XZiHM@;fdYEKg%Hr?L?WYK7|QDM%vMR!%5w~G5?tMMT%$5HwtiAO zz7d>VJxl}O4y#m0R@qg+fRqSiU@zhJZaFNkS)3E1lpvu?TqOP6tM8fvvjj!8Y*WYV zz!CB858M~?ORGT&dfx-ci$(N3+(|wLPR4|$s*e6_vBQ43kL7K8-`uQ`j4-0m;9~aOvt#xG6 zI9&B;3>5+oKlD4D9T;l3#BBoO+0|p_%6+sc7SMnY5GoC|5Z#f5Pe1UX{Y7*(Q0D+X zIELT=L^)pXXJo(wv`lnSC}ano{{wG6vZt*?>%f#?9g#}!NGUHCWxZ|@<|(|rg-8-Z zn2M1AjY{QcnESBj|M{-)_ptnvGQ`l$n7uvhi$y|7u${ey!$0%z zZ(BwW911JLRbUAm14o88|LT1!yR`O2z4^5jKJ*9o9|W!{A7OqSd5N^Lbf9zqTQV(y zBjL`M{NkZ5Mt5OYuO0m3kN@2@;t;rsd>wYkbRD?jSmm8VE)F91zySR9-`*Q_*FF8| 
z&$f(2>oGe54g$xPuF8(kT~T)p-6N4#d3NBq9#{ec@wzwuV41dx+1Cxn8F7ZuKE75z!aBcw&n|XD_Eqx2 zOix^U>5$3XuW|KoxaNj5w=j1n({7Fn=_p~{d+E!7Aecl1Bgp|uJFdNB+YQcuVpz|x z%vxL1&E-ls$K(g9d5w#D>sDz8>_#dvMzVmN37aK{#AU)Nhms0V4}>PRhs6B?Vb&h*d?G7_nYmaV<@$xLhRM6iP%Ub0B-n0I#WNN}YQgN9Wq7cEaJ6f^O!Z z7-WGA%urH`cPyCK+V8mwRrN&8XGfP-W{UxI1&jKw#^>ISS_h3$)0@!Q^9>;t?K^h+ zR;91HC=y+56aWAPK1oDDRI|E`(`qGATn7|nGsbw_vEOLzKmcKg`3lW;jjYi39F`Lm z>RTTI<74>SCgdo%LEl-;zF70RyaSX{s$E#D{@L75C7#o9u?L;Z=OF_#dNwVddx+L+D9PZNRREW$LSkivop+E_yX zl)aWU0bCY zh9xLVgw4i&5tH9equN4%rFMnECT(cNlP8>SBo9tT{w!6kE3g7LugVKk6(2;ge@M|`7IXlm{|b|9~H0NX&jr|G=a4Kluteg!e{m^Pu}HVL?ePUqWfH5*i? zumg$&5Gvy@PKf!|Y~l$6yGO0MZ-r=8u#mSXeYq97P0gRTLJn{u!lNnBSeXXZ%3Ky5 zIPQu>wdib94Yde5X2u2(^r~wNw~%411Jq-(O14hW!?n#>tXoCgp!D7z?3=9DD%j{j zG`qG4sx}gxI@JVPdqxTcW;h4%Axk%y&sJGpu5ltk)HMQtqi*pmv2S%QY%#->{E-GI zBp4g3Kp}{!#y(S0TPvN!MA2^!5Rk++fF~OIOn;66QAcjOzmKpvciZ7<&FuEZ`O32` zDV-SX>6kajMh3x6sK~dL&){Zm?X9U2n54#x0t*qU+gja`Z2;O?Xfbm6ZyxL3@F}+w z+ICQJGtAo)e@m}AnkVP+9G&hEjg{t&_1hNc)_vN0YGlTG>Q0QB@%vGru4$by)TETL zRb|3W#NJEwO1y|!eV54?T<8z1I&cfD;U)Yths{UIL*&roh1 za4n6NO{Il#a~nM$f*Y=oY11RJ5?sW7dx4%g)SsT(|9QQ&2<5!*|3rJKinqq8s9gU! 
zb(YOm^I%Tg+L{eizwv6@u{Ggtq1)%Q;2z6!zqJuP*ZV&aM|8>~P6ulSP{^$Pu8>-c zAv@vqQ~mK<_qA@+-kkG%|7O4cxqkN3aa^;Q*zgb3vZ#%>J5L?y4YZ8$h{hR~#n)r} z{{*@J6YzTK*jD9B$2Jg_fofO*07*)aSIoxP3|HF14n<`(7=>@{GxD+D|A{!L(dfr& z>0Ju1iC^P4vpKLf3x$Oqkp!BnsbDBmEi^{koX^aC+#4R)o+Q4b1fya~6=oOpw^+>@{W|JO}fO<8!x zrB+`f!L>xFg90i@m3@`4<-)v$RmVx6k|(d zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3;uvRpZCh5vIEUIGcu%i$Q|9eDZv4l^ZHsny-l z4og+QKl{QABl{#uUDg?>z{{I*Wt=R^MU`n^^9>xhDV-(tUBcq6j# zjolyC@9uW*`OI#l;#jWooz#yhzDK-~B~9^ttnjz;f97$0e>;DhAAS+F%~YrJlYjJJ ztn3Xb>@dO&=lxz`F~nrYD!Nmk5= zn596(t0F)`n_}i$NXZmA#msk$F2-z(GAlScC}Kn~onZNfpRxNDxqlir$Ljv4akD=~ z&M9>NACYqk-CrX2k8%4$)b@CA8()A_g{&zE87V*BG_3I-bn@eOvzq~`RZbX5*vTwr zFa>1g+E;DsWfWu@I}bNS42Cq~rstlpsKu*>Xu zTj67OMVIzyEWoJ8`espQJ4Z`zcc#^?M8LKB1?JVfVUIg<;n8fXro7g^DeZdVcJtJ7 z_Uvo6y%^E598c-1=d)`Xr_ahV^|SZPyPa+4W^Oq`x5V-<)*+=hvy1iYlS%~<$PS=% zmjQjUW{Fh_shkpSt!1{j?x}g5JdKDAX4n1gERn6=8BsWgAp^=5=dn2=_hIUUESGI7 zTgvhFc5=vT&*MquP7t|JjC`nFFsZM=hz9AEM`E@FJ?C*2e>34{m^ELRW6jVhm=VxG zmI}dQR;80}b6Oon7BeE7H_aaQIWv@E+pJSMy`)|k!=vq4wd9^^8ey=~>9YJK%-MTd ztMbY2UC0&ki9B9A$GO2v#!wH-&_kR=I@x%MSXs@SXEK17w?$iv6lZBl<-YDPuUeQ< zzmp6hgynR`%Q{iJdsg0leJ7+PhcVOjn*uxSmcyOT>8T!yB)d7Oqq8nWT<$3OrZ{qn zby^vr_9FTXg=HerA#05Ye+`?spIFYlrTOdIc9c`C^Tgd$vd-oRROv%_T@nR0V;rRC5kNnPrl#94qxN$+W+#~VR5U)LCQQZu}&>WXu0 zQjlLn>VR*g*#+20|IrJDNH`$+C`PLK4_r*FRQ2ifvh_d9UHT!2XBeNFs0n7X>zst2 zLLer*)>j4~!d>Ak&(ep(Dg=@uMZg*X5*Sr$EC^{kkViGcM4_)nnJmYo@402N9Kgv5 zsn?}Xf+KcUG($4YnPb4CL27S_;+Y&d5mTyLUe<0C3B%@07tpBC_G5 z2Femx7U~jKR=ft*=t4P^U@bNQe9CCw>BlT5%@E@R+zH45?HAPzzMKuc73Sb4vM=!u zC623uOpdHV>2viOCwGY<$Jnwh64EdYWZ*#fSLQ|#Qoeu<(849=%_Ae@BjHcV8^&-q zkTs|;LQ}2>v!DPt;R3wmZwAN_i)*em;6X?ad?ho1C5Y`bN0NgxcF*c%S{lYn+?7$3 zyc(8A5;X3RE)cWek8}+#Si2QU1i&Q}kjiy%R}jBM3CskF+}&qR2)|oMwbWXOdN+)+ zZGZzC=p)rq<6|N)zeqx#+)K7HPw|J!CbW6wA_yU+;E8;Qaynvy3Bv&g37TAIu6j{% zR8LyK&IztHW0ov<2N4(s0H~?iZHaj$?iyip5}>sIghE=C2VXLC>-Qvf+m0gDr*Xf4l{feLDc3lqORJvv?_OGN;0GX z#t`Waxx#<*dq#Kst}H2#Acat8-C81$9sXoWNf^e22m?s&j}L25raVLfZz`y^Q2_to 
za2=7dI94XON}|Gn5zx8aP$>{5;-W>tjtBD+!zE2aTafEvF(~BjbGaL%5N>VKNRb6qyDNZ0j0LGQ+1d zXDdc(LkJz&h?cv~ZY?EJ)eR$nhN@%jot23BaeJ*@D6zv_payZdBw9iTjlymHnN;$m z9^4KUWLji0mn{w9NAe_wL4gcPTuYkaO<0fIoUFspZxk?oiVKSDq21vws7PhmaTDZ@RJl?RT!WttW^Qgm$O%9~vQ+8-cATT)&tcy^_d?K6)1*fhxXVgK zG%QP*&p?WS?rKG>c^t)F1Cn8&=Agw*K3$3maiAqvy{&{+=^hcd{6ZHKIMxYrM;L=h zR0Dh*-I5g_=aNX0B=I(azR}s9>Mz?%_GJX~2EMDG9r27N<#woREj98r4LpKFz{2;F z562;`HH)NaE@)n62)sg2dzdR%L3%&X7h()&t$K)UZrU(3YJ8giq6Gp00%&?MjVU@*6K0~!h9my01pJ5QKQr}OU`JFOBRURHZhDd?k=p4Bi_J* z;yX^r*JtwSIP)wQ4(#Lx|3Ve;)R;k{CHaOZqHy-ZjFTgEFA=0=rL=fY!&sq^?U)10 zRjb6PQTbGdNAHuL+)+0ND460C6V%cY#A3R*H%nY1aseqd-{^MH!BjH{&Y2rRzyx zZmQT%-}?z21XJe*o&i;Nq&3<+Vc~-iIFnk$t58K7@Tn4?kn;e83^jCZp+w-UsDi@) zO{6$4q1c`mY(aI5eVT$%w2KF~`3X_Nn1o~8iGCTVn^ss|2@;emfG$d&5zFd4pxqd1 z1aMFh)Fl;#=LP>*#HwOtQshTL_)(>%>95y#TN z6KRgq$u%v;K&s2^HWl3Mj)Xv4Ash|>$O0^A!^XnAtsqVU(^Llp67*^$OKP@59?fgl=Yyiw`ssLNy7F{ByBj7rB)IJC6S zx@+KStPMyYTRwTU;Sn*}`1~WO4suFuNtE9M@l=c_N|j_154i;tWkT3wv{+4)AtTSo zUc?pkF#wo3SOq|#XXrW}WqXis*obFD6_zd*La_N!g@ZYCzo* zb!-V^!jQ=rCt(=shtOQ)7zE{5f3zQteVH1jxMFJNFdTmYMT#DoWxR?pCyF6}pv6np zKTX=mIaU#Y@04?uJD}jH3N;1G3|Itp9WJT~>4lR~Ws&d^Nr3^xmEb?g%)1w{#m3I& zP=Ht>lFsj^{XPW?&=kXpH$l}dnQ`P2d_=k>h#z|6-7Zcrej$Xx*26WCSKu0i12DmptiJ6e zuYFjU87xiOM0|-G9?;Kbd%WxopA=Qi!l>HNBreXm_Nin%rPZd9n!aXX~1w=}d zV*b$hR2JEbcEDkv8H!h13nMD7I%*wd7M<~q;z*7nIRWXsgvOiTx)G)%CeZ+sc_ey{ zspFU?9?@X!Axh=P*svSy-(AI?pT^4bLG^Le?tsnI1t=?|`WggePXjSD6r|eNZOEhL zn;7r3F8|}y_hKPb_Xib8Eo>Z~#)GNMSUKcIOJ_X|j}c&=l3Q3AgTqS;FdgD0yA81e zbBIw494uw(xG+TZ4J1NA*9x2h;gwQ~H=7;x${=X0!BvZhk4{}k?E}g_V!>np#3NuG zp}>HViOjtNgOs(RSb;jVxyB9xTf8Ia$Py+elQ)o66xVs2hO7by&_QBn25&~_pG|03 zd6GXI)*G_}8l~9J)ah(B0>=y zW6R2(%ZRpesO^I2Kd_N>BcLOE?)VO(PCk*m_*7gjigwosU=LvelA<*7b&+LmqdB3S z1{K3G3GF%6H^L9~pbma2!S_WSJvsm+q5`kc%giCXr)+RHmwQTFdQe6YiYHuHqRD)Y z1r*qM{FhtX+WCXNm!RY&(>*x6DuhEdgbGA@Caqi>Lm^Q_8y;&RY3+GcpUW zysvsT4EcgJXf+_2DYiZw<$?JeRB;D5#4)8$WKc)VQ2S+s(V%O7QVgVaP^1Pb`D zoLg+rbG|OAKZUAFg1%$lP^-4vF4BL`=fJ%rk#(3DewO1$` 
z=Fw&aw1zT6mX>>|&z*D(lN)2+P^429jS`N}kP%VI zEZFNcmK@wN+v8?t&kOxfvAAK6Vck#nOE&bzh^m-d!lnDpWic9XISLJmA85#_+>)yS zVwnMf#duDHOHzlZGD|Lv+Y9lji`PIe03Sn+=<5}=oq-ia)B}v9Btj756*9`WTZ|01 zL+1!C+YHH2)kD30uGx^+kPIA;`~a53!ronqD**s^;nfgEAiv4)?GD*7l<3u#4@hFj z1@c5KFN-VJ%k*JUEcjSirmtL};$}yTWLiqZZ9rP3;+ z1^BeZOI8*_lqLZO7bt>K(F+Mu#6SbHR2;>c96|>A#IwMQ$cu)>qa2st+%H4w8~uiP z(4G=Gw6))PJRpL>Lk6R&7mNU+I1C*!8vZAFW|VZWgx3Y`g1XoNEG^pRVOUY^G2jc( z*y8D_9vsf$b<$jbO^GSGs6lyHJc^NG@6G$Dtv1s>h`PxIDwxtP>!fr%az5z^L#-gr z5$piUcgTY%2Q?x4K`X)TY0;vNUicD8ZlmNM-HZ(}0GUL$Z@h*XG7; zER|>n6hD^A1Ai~YhEky%N{aYoJ52#Za)MC@OClmA=(}U%swgQhc=x$mogc7)*FgF; zGR8_xc@z8oIR4xR8ir3Kmfa*jxAF3FI zsD<4!p;lk-6PZe_qr3--6;hkYcjPga4edrfxLXPpii5^bgNWtrg#aQd|f0&GeYJK6&>7_h$#;_Fl#Rz7Y73>IZEb9DA% zjjr_e4p9Ur%IVQo-7jLyLbWvR(XnGic^f+D4k(Uc6u#~X|qTU~&&Tm^m| zNCx~sh(-xiv}#PlNeYxQUVjB2jlj7JsfRO&q$GMRjTi>9L;XFKe0{gThalB4t{B~| z#+;j829fL{vnE-MqDJgiW4$6_Zz1!L7*#d-SjrQ$SBj$vL=sADWwrz(>(p5G(z|3{ zf&;b5f{#JH)+l>?I;2kZVC%3dHYr?8a?hSHjPLFeN`Jb=;;WZVkl&`cY~0dzc1IwS zc4V!#a1MGcXe(?kwdB-7Ad4_Lr)WnHgoH`6#J7hIk8oUDMKdC|vV2@ZlN`HhQEq5h zJrYu(Nk@BDEJ#(LZ@LHkNf>Y^!gWiS0qmy0v~bIBZ-Zd3Adt!dG7GJ%zT}yZ;}sbV zG->b5Db0zvwnw(Cv6}m}6P&#$9+X5~Q}*U-=@|YWfkn=fu~$g$SJQl9Vj$JQ#3(K1o)NiO+uj9%OJa8v8GMW$jpQ)G zpvli%Yz7dH|BX~E`lqQ$!Z4jzH}C5GJ4_-pj5@j76w&1hysxpC@uFY~scydZz2L~GfqWe5_r#!HM zQ)24*3uB$1AQsIIQ;M6q1oByvWOWtm85WoV&}H*#RFiWR8Aj0|IFK#GSAHG^9^b}% z_eHI|sl*%5Qj>aiky+|6l9;sMkL$r11ewtN_OP3X4ctK<`Db`x1&HDuzR?*8uWfYP z4~s2THnfYC)iGwQSPB8b$3!OyS*(4*LUi+5dle)}+@LP){%Z@UzgHEGM=U zVd=ZWr|Uejnrs39qTVqnz$!R{9Kah5YdW!F5=D#*$ER*t=AXnJ5kc{T2~O?FsDl92 zN2NPX3%sLPs@fB1CnDI?rddr}bV~-n$I}$GrClT@Y%MOAf)~xGiz$OT7(k}pD2@g9 zs?Y*%i7QG^RCTS?Znmjr>1q^Fz1l8?j;Li&OIy!6;nsD|ATMn%vk;dFj)1Y-DA)*8 zx54_7B}s=za?{y{g&@p5IURhf@2~~B7Xwlc=%OAC*~@BZ2b<+Z0pnr$43&?*WGOd_ z1oIW;l;i=aeve9+ETBE#2{JAtKNGT+8{s4YBDL`m@de9tn5T9GSxICsoOJwYXA^$v zO9#V5CbNJCualvsXVp2nP!41d7ODkbL)C!^KwgdSPFMnq%fxi5z5yYW6teZ`q>&C8 z2AmN#e;zKyFsJOXnl255&{;7eb@T_y^>DF{YUJj&&fR))oakT1jPVow!z+w+DAiev 
z@=j2dJhdxd)kIp2SBA>GY;9JVU`bl#;=)u^#gs@$WN4#nEl+#qmx zDauaA;|C%S0eu~F2yfE>a#U{(pMs-fZPlQJ_(&gYX=5Nr719Q2fk0N1RfAmQyCi8v zuy@_{ykn1eW1Y*B1jq!vY5%O`s7-nxYpCIHsi+h7YgiR5CLxS!8E2<9d)9tZ`JN0! zgLRRPwZ<)!diOB;| zB6-mAgG1jAT*31!9gAX#;}g3}`v6q#R_2&Gh@l}k zxh%Xx@Uk4l@J{CAEn5b)bl?@AMLlX?2?aq}XqP_5&YQQ+P{Pzpy~GmhTJk%*d|`3N zN{T{IfjWr`8L49q;E#%+k`FO$#Y~wk9Q5%EAV4RFi(rAZQ9r!`nbqyB)-GvUv>Q=` zZ{%^`9(A1$=9Y)ZOr#={MK!7u_#CTgcM3lZ3o!HR08+{xj%s#v*ah^$wZd;0`>m6* zMqWfYw_FKbgyIkzNVM`c$cCwzf`@X5bR-3+8%L1l(%C;ehv}daOtaLV@dk@r#1>u# zdW3;wNzagF8#tMZ-*Kq9T`5_4hJ)dUB>?}TjRhUPT!>{Ia86GyqU9+rKCqbo214(W z4pvn=?cB3ydd$t1>4YtweyXdkvnSy|fSEcZM8~9jF_xW-qU4F`l7f@QcRIGEgRWVp zgn~{dRQO3;c;Sje>K*_Jiv$C37f|+QV2C0KL6sFo;X2hjs2-jt8%cf=KWZ;J#3q$Y zX#;_n9-ODVMD~cMuI(1`l=SQI;#AgAZ(um=S&CcEw!Qt^0R?I+Y9iZ`jE=UGG6r?@ zhb~A=3L#+~sa7$1d+n&>Mrs8jTAp%3&cG&Jx@^_mN{8@4($A?*H0y!Ea?>c6s1mKL zwoV6RY&z{sXpiFPyfMp0VdDhQng|3y7@AO%G2Zowmb#o#8tHY-Xy-qwe{O{Gb5hk>r zC4S~-v#Q7#P(j)-Jc70ad$9ynX9p$M1PaaJ>M3dTRqBgeR~u4$r#eBTEp^uE7$!@e z2vNy=Pz36EpxtWgtOU;E>{=W~u?q5Ng)?mJv1c9D(ZS>Zi)YO>7n$OcM_H4S(MdyNe0@0T zx{{PSM1$Oesi(dfAQsjsbJPw{fS+KdI>_H3>N~X~?bWd5Uo1PElCRT-gx3Mes%-%A zyw~X}9hlPK?R}$dR!AMfCKYG+oKJUD91!D+Ws7B;o z0G?3iI;KoY0G`=|+3*|wJmEAU7af%f?e!1|B)Ok3NLI)_d>TLx`HBwSUJX&fY4-$L z}_bd6%~-GlOjQ(5j!Toq~5vwH@`GV_uVGXBvC!#QXk%iPJpYaZ2@=+$fD?eewAYAtT<88Hl`YC%=ch~ zH!h5Ylh!A9b#PR=l&b~N5#lY8C8W7yZIEz51-PXYwZl2|6&m&Qo^*=ZxHS~YP+w$I ztrrCkfEv2aE(QRm7*3~ovyneAy6O95kIR{gRO)4Ml_z^JSyHhCLl z#KGa~Gk^3Q33%S7-W+9bD878v3m|ilXcE>E97&t+{Lyandf&QRAn+cpN(%U^uROpS zracoR9jc>fcLOl<=vsB0!bKpWs-D)6%ECdy;Rq!ZE_ei!A_w3I)`84gh>rxK2c35c zIuUZp>j#U?a8}wNm^uM;w8eu~^;%@ylml)kVjvBZj-0;QEmDS{YBz()M|xDLm+0ar&NuC00vi9&!0Ssh6yG}n!MV$GtxtFATbiKw6+ z^6Pq6xb%EFI0(ia^yM5}+t6t*kDZ;=Mhk#I)ybT!P7{s{Kda2van8d?2)SYyQC3N! 
zkD>BrMlgEo+*Z4YI%0aL!Uc%w+fhSj_M(n5c;$H4p)l z^`X!7NGwbz_OJ302`QbT=K7z*-0vF)4~q27rYdZJ)UvkY#%k_(J6so{h##DWDvG`) zf*M-y{FR@_13;{|_P2-$a78A)FJ5Tp=yM;3?uXU58X;ZMJ5#C>)G0&Xxx+Z?5CK?u z8OQy8DpL6qJ~CaHWMeHHp0F$N(LCiHlqcohF-UTzqaSp%xmQf))`72Xn!K70IzvMx zzgYkz=D?)*tQAIZ#z)al+QKnZlr*Ex_#{)??VTl;StV}z9tnU(U`-A3;2wH(opu=N zP2;xJhXQxG3r6{M`So=_9fmyGh=Z}4?36*=hd^p5mI$&e^`k|N|=8(<=_=icMSqqLeVangt zVO=-AV-PtMx0TpPFT|f9A{GP3qpBVR&!f~MU#c@7{fv#Pk zQFEN{W5;Qn0KsSAN^kiqbztU`^jb@c9s#}Ez{Pb-Q}%$%9bn+ekWJZ@f>c6254@kz zH)VmoTOhRN&8>Nk(+40;vr66o2Zz8&fwI?q-re0kw}0<6=l26#Byyg*1Lo}j000JJ zOGiWi{{a60|De66lK=n!32;bRa{vG?BLDy{BLR4&KXw2B00(qQO+^Re3IYllAEzQR zod5tG9Z5t%RA}Dqn|FLv)%NJW>ztX%%%n_uBY`BeP?X-gGz(xuKm^5$jSGqm#D*fp z##KN?EC@eDu_1_DQ9+O<9Rva*5I_Me<^M0>vxF~Ns<5z`|Svp#(CZ845*HSuRmw41_attLx;~QO(re$9)nRJ6HAoZYYOh*3stXzy8Sx{Rduux};@ii#2t) z@n!8Gv$cA~Oi`=sL8YtW?5Ah83P1jEyWrCK^(fvPvbiMDh76%?+qQog8yy`DK$9j-NKa3vQKLqjKY#u|S<);#o8C1R)Si{M8zs3n zXJ@yAqRx1|UL1}pT&538h2Rddf}HB_t)gt1Ef3vcy|=dm8-}k?{i6EGGO*t4-kq&m zx4MD2ix%*TNP$wS#xzZzVHmV++g5GgzMZaJyONQSL1}5J34{VI0XvWn{3;?f7w;Lt^y!2yK&e7^ZnQBj#SH8qIT;m{Je7pMUWfjpoakf4-02rLN>4z8@M zti-Y`RZ>#Y9=Hpr1WEzXcY;!Ct%!VW7>2j7u<$=cJ?5ZL9DYy9yt5gyVkW}r?>(#- z^;cGo9?!0)lTpg}y5w=g=fiwN-DGTBaY(+|__o|%w8;oRw;+j($EHyDW=!6;!*Z(R z9#KF4B~bFa0T2EIwCe$H0r7@mNJ2saKv&>xpry@bGXe*Thzqz0xW=zqfm?tNfTkj1 zTm(NFm;hV_NMvN>6{Q~k_a_41XZgFL>ABz!TSU_}M`p7QSd?zqv>rRLOzxy8#bf ze6VV2YW#$~5x7tPjsiXe>I2fKQKL(66%jEE!!Jomz{fz8&1Msq%O(DGLPJBvKio7; zF-=oaQ&X=frWyZ^lTbVNkNB}O_YCx__hXyx&E6rh07z%n4;;c)z6tZAAUhJoAd zR;FnZ9v-evo;-O)X~%!drVU*iPjYI@NG^Obfv~vM^q;xCGJmf8G&I<8B-5(jX3wS+ z_7pWeY!@}$ZJL4$_30EG%}WB8 zcLOon(*rPpWMBYbh{#4wY&%d;US5vd?f!?hTSTxdD;B6Q45KV1C54NmizijtwKX4y+rSSqr% zg7NutaKt|mNHL{7$O(by0WKgFcq%L`%%@dp72pjH4*q9|ZUbyeDPRys}GsZKtJ+E>_9w(noo z8%Ibwl;-fut^tI`tmmp$Q_`9~a-kVmgI`=oWPOp}&OFQbv}w#*+n9!(=Br6vM?mR? 
zJ(M0B#L2=qyfu{8R;s_Wc}*t&0KMuHNcAni8<#;b5tvq7T>R|sfd`R52vGh{?70~5 z0`}v_kNfHZK(E&XZUtifalKONU~O&fyJcl%ewBQMGEFHG6|!P7yzmlg(R3vKbdJ1k ztyPB|$aBN@u_XVCIQADjj>j2-JLm)<@rN*7PEKdsdm`^1_X$%f|9*XV79q29sF$+y zuApy=+C8T@H8O%XPI`Qg&I_@;rPZ}W)S4I)HQ8{^ed8@S>zW~?^v;)5JOO8^C($Z4DTCIa&x7$m+UavzbmDIFp z)8l@X?{qqSUaxnNVHkM5UVwo>bGzNHF3zxAaZFQ9pr{XYe;fI7K`}eF{KUCYQH;#? z*x312)RGHFo1?tLh_nac@MU82WndMZ503Jks0eDhO-YcCAj7~>5+~|Hb9ExS!+%9Z zzcunrqi|p!OOmUfCFs}6FmHv=PI{vn_HB>qQd&~m<)=KxUcU|OW7xKB8zOSa$2x~rcs%OdxpNA@Vn zWgaV>L|I69$M3Ad;%;NevO&v&@!j!RSri`iaiRA7g%N|Q-)**o86Uh3{oq2_h07xA zkI{;Yiw*5P{tyxpLV0=lL8a7bU}9BO)tnP2PI%9pIYWGWJf}{b0#JE*dGXq%KWACi zH%h4+d_LbmpU?N5QmRhf1`B_$>P^wUpH zpU>9{cmS9M)c5&(L`6l_0MBYKbY^vRbw?3#iAYd>ettSI4!9Zk9+0I=mtIj!6D!e@ zimpQWjGyS!aFfyTvjMU=Fn4t0;f&l+Vq{-p~2?dm9{%f ztNFKzZ(kZnWqAV{b$LEv^Dg`8^2(HkZR@Z6X5Lwe=(34-Drb@GsAJF(V9Mo?Hce9u z!@%eB)j1ehSy_aIg(;uUce1p!bpFM2!otGDG)<-FNOg5Jv9YnWWo2c%%gf7m1&$30 z3c|815fMZLo6UA*smJHGLW9=LSnbR*_+Z*&R8`cYd%yb^44IR$ars+qDJ&gB>-Nu* zcV1EH4T|X0{ny*>`cCT4Jj3=e`6SqXW>A|16M_2UpZat{OGHXGwDI z{2n*c_4HHL@eNB`k#}wcu}v2~5!9~WbXI9YaFk9;sxE%B>v=-`Ko_#)_YS*rvvMj4E#gvniLt$YdrKP3na)EVmaj_~aETp8QgsQ44)w_3Z_2Z8} zQeIxJI(6#whw;_b)p)&LRb5@Js;jHj`Sa(mEcK{R^4VdyeZi#qR@Q-qO?%&Z%bg|n zmpv0QBUbI+x`>#hmtRfozj@NpFYm;=) zh?wx@4Ct}bO-kj8f?;Z4@CI6kWy^u}UteVv<>bP=h*?cdGuUaf^f1fuElvIC7Km zm(klMMHh!tQ_~k$WJT4ad4uWzfFFRjGv#Oy2_wMhiP+XET6FW;lb;d>~B@R zYT-C*-GQ~%gjPrRdDEAzt?V;Xh>n>{hgsVlrI;lqbt zc7KAm!hFK^jp~y<&$#ySdV?s+A4QOvjAoV=yme~KHI%PUYS4pCdvy&(lBif_l(~DNyF8( zBQl>em^$)qVj~JeC(KxpCpGG`{c&W(!%K01LWu~5VTe*nlu{jZrXoEciJS%e9rzaT zU93(UHEKj=W@Z8~6L`chjFO0m2z7Dtw|@QloI7_;YHDhl01t}D#Knsj`|iB+PFl2R z!O^2frM9-#2D}2iZy3g5%d%A9dLrVF1sen7fk~mEp*5wYr6{FHPEIB}JKGMt2FwGF zdpsUhR8(}i9IxnF;VfW+&1NgDt*up;q~==IZWcwhSWSv#|AWft_1O|P!-lNlv5CX@ zc1;4SR>m>$uXl6b@bUbtZ1v8ZGp)8I230+ccivbkyJl{dk~dR~7t}RI`D+>QWR>vw z5)V%edyObWWGMI8lkRDT#H*`qNh*r|!I6=)Y1 z7e|8z4Ja=!_Y*P{_*&0fAtLU}ZQkv63lOX`B)fq(4Z|>_qN30RQ-C-1x0QhNa!eDE 
zAT9NKfZ2e}G))N!2@%sY1$aLIhRf-6Uc&s;)Kov-tMu4ax@YCmBIv%m#%M~p0wibScJVFQ1uOnZ4fD`;j>R@y*rzwOA!`wWVmu)GrEfdY6}Jxs|s#kiJs z#ru1r*F8D*DGuzLN_hMd+75dDrI8sI28_H(&F<(^>(V+{%ze2vmK9pB-8Gx~F8x%! z8433v10_&=A8Mi(Wx0Gxmt+1CknLXr+$VH{Bq*LZpcD5QNEiNRn zYuO$Pufw9M*0;ROuA1$;ie6yDXK+=QKBYav=Sz0gF7O~%FNNmxp+V1uG2AoaQ8s?R z1c$qfX6^1{=AGX$>~3$$$jNHU-#haC9d}A?R*r|G$L}D#{^BKb+t++F`WB@Gb$|2U z&-$d<@c!Y$Qm3*5M8&mMGN3ym5zA~7)$!!S&pOz;KZJyJv*X=!P7ZGZrV zcjL?K=RQ^Av>3A`7Ge_4W7@-ucEqY7?_Sw(W1D@^i~1X2}Px@h2y= z?lF4t3x!{Q`3Y5J(X{FK*Jqv@^hQSaSr>{|o?yXuN+vBz<=k~6m6dhYVBflX@kH<7 z>DNx|n3tkN*+mfWp(sPW8wxM1S);oD?!7e&$bxKEO4Vd#Wl>#S9a&XX6`z}%n*fCP ze7@Mq%1VdLX2a|CQczGJN~uozMEi*Z&_XHIzPPwp&YnG6*B2=w;Xs#*YIyyDcwMvW zq~G~-lYM~T+}vELsi~0@Cr*f#u#OkO=nFJZN;L#}1i%33BqCvcukg&7Gh(;fu`H{n zmNL<;6abBtQf*G2JSl-*r0q;&hkA)ujKL+psIH4nvM>J5g0`;DUsQW`&-wA-__ZX) z@4@Dd$(!_e&*4S+kyKW;r8vJ8Hd{5JA?rMiyNy0H?eUza3&gx>_lH*crxog_TY|_M zuu0y`oTmnF{6exOkG+}d+R(m}mwbBrO025EE0{`sUbsl?r%#{m&!`7l_+JBlzoFAp z{}X2_^_IZp4wE#Mf?NZXM+ zbsJuohI2$8q)?z+)2xwiTTKV|bWp^tdGR5gHmASfQ?K5DxeqtepY>?E>6YRaMl~ z)F`D?8?9OaGBja+^S9(#menvjJ6oQ4<{8oTy6ZGyMLQh7Xoe5g#IiK8=LFzz!?9z> z2o4Tb`Y`m@b1dL>;6y;P(IEZplx9S&UTXmS-XR?m6QjIdZyVjNhsI%zhGPM}frf_< zAO0WpXj%19-Ui%Ih59JOL1f5q8gAZ%dCUE$7meyNf9UG3Mp0h2ReGg8Gj__)Pd-2N z=1<&Xrk-4Xa#~f$2`4SuO^2F^JhgTs-BVjh&GLI?+8vcVeD+>xdCf;1$Uk{c`oNo? 
z|7Mx1aKqE*WkM&6`sOKN{;o>`$Dp7fKf_K2Fs@rb-ES+fP7~ph<{j;E`r8Gr0@i9Z z=eITY_@!7W6%BOI-_8TufG_;0=-ULAYC>k~Z*72j`oKj8sL(P4ISU|$Mg|057>A5CwM_{{yv0JiW?NV7Ch9HLOzIZxqT~p|m z6YAUBdXP1I{uvw*v76PVlWj91GJS*l-!7w*UQ&nRujcw6KIPTQiHv;gBcaa{@tych zRplIF*#p(G>b)m=S;e{cH*DVK$;8-Gh5bG+lB^@!)WPoMR(+o+#<#e%51_ahm&=8Y zuSx=lbeoRbUeQV9trySD&CN}{$Ud2x5P@xiU4R5#FS|xJ4hgWKL4yVsumLz1Fg^>| zsM{E}>#=U2Ygt*Dzd^PRhEmF_LzBQ{pVM5Ln0rl&KKp67{Gv7Ahhl;!I&-;io#l6yQ;%)D7CcCF;{iK)FVA$8fOT&(nXy@mk&MC41Q)Sv(_^*g82 zDOFWfnN@Mx+wFWm_`xv(ky?@>_Uz;m^oz1lBGNY!GlIr#AWmDNW zbJUS_rt#zx%^1}GeMS#`IOV0;A-nfn@?BD!VTK(+?zSGXbPXR$}DBur-qqKCb)lQ0N*Uha}^KLD5{R6T!Yk(lr zG(~4kqBTBifoO-rAz@)*qJ!}7^qg&4i)_@-dx55Ix0~SLV1GkyQ@z%1&CIxjgoI1~ zYi1#E<1~nQ&JwG@l&X2~^u#Ll`nr#TSh8XoZdXCQ^!lIYCwTHAomnc!k?vy6p2k#H z4y39qAuOTZdpEwbXm_8TsC^xG@JLQ)NiBLwspd~8`6sj3J$v?K%a$#olrnV%c@WSb zK&tY9&A?)y02r#JLuvc-kUs{>&CQka@^T08s%9qIw)g|;)xg~J^mL9KIpRlQvWE3) zqehJ=C@Ap9c9XQ!`f>1g{CqDWv%|x~RdR9x7Xs-0G6xPE0C-q`o0FE7Modf$J9g~w zr@IE~y&TmZE5ISU-TtW8>$UvNxnW^pHl@@&?Ti`P(OG@jQ~twhB`9UQP`lMo zZQHgaBqZSRctpp30K~<`5gi>(OiT<)DJm-~nK5GqAt50W8yib%YASBGTRa{QQBhIE y#KcfrTYCu|!o$Ofjg4jO*s(-LM&fe0{002`61^@s6Sv_$j001fidQ@0+Qek%> zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3>vmK?cmh5vIEUIHg@SPsW?yaO-a-({wxmb#y( zV^z0Qs;Z2PAb>l7pxpV7f4}bE_|>)#iQ@|VZ?{vX%(vfqpU`S)B|jBQ*mMHf#B$-m#@=baLFk$M{Zso(i~)K>zZ zr&RgfQuzD&xt~k-4-07dpV!a-xU=+YIX)Nq@mS@Lb^3li*dJd{Ki@w7c1Iz8ev!*N_nyzwja(edRlX1P;}PFy{2&>V(|cLrqw>G;d3_(9kLHJ8L~Zk^ z)A`9?dWfv}4LR&E!VTyBxx(TWb3C!~XN)VR_qCR4>~X!IWPOAit0OWyHDa@fT;j{| z=e>mYzFqIfn>F5f2YxaJE=G)h|8ReP;s5&a+lB5`nunmpB=q4E;HdoKHWWGi&QB4L zaKCG6Z+(ON{NexnrNAa4Xm6Pt4>IeByZX!f7P3w3`lz!LG^#?6q? 
zIrtKig%W&?v4+4pHsUk5^O$m+L?D>6xVg)iQz`_fMZGt#sp0)>wD|cl@IoLg6^f*( zbO=_?mJ9XPNTC6YST?g_)tYr1C6!#Plv0auYSdJ7wOVSet@b)vYPne}t+v*B8$I>} zjG61D*WP;XWALUU53V)%`N0=voN4B1v&=f%>~k#AXXRz9th(CjYwWnw1_InJyY9C8 z9w!`9>EvUloO;^nXIygarkk(ba_eok-|;hQ->Ci`wLc^GUq>x`qZVJF^mFxB)cA6( zKQ0jjC#IMYv6ur9uZjQ(9ThX*Le5c1mycT=g<#6#7S3Ndb1Bi2rEMA-}gkM~nQuOEeV)oXZ)xGs}W=}jRn0sg7n`ybWbL2T~^*|B; z*K34v*Pb^|a-hw~=ZR>iZe3E9oejoiW_@{-CWPUjADnTa>f zjb!3ha!K=12X(WM0I9EtdnJ7B`{XRv1D4#(=GNQzySI*!}9&lcKj)G(0vY9|L? z!);7n&ONsTf8UjOzis!VT_zE7Vdql83fL>z?O@R$Plyt_jD#|g-+EBC=O%=aR2q4c z=;1aygBTo7WE_Y=C6~Qz^3DC;sV?sItU7nRjchtRHf4;E2^|XD>C_zeS+nOb+pS7= zM}E~AMq9VdxzawB7fe+&38{)r7mTX=Y#x_rgq|AWGFEA$h4i!oblPRH%MJiF!R|l@ z3l7T7kf{w7DtR~)84<`n6166(jqHdw!~Zs zr&!B`-DPH?iyUmJ1p~~?X_c?-dJA1Bdro=Bzr^L+h-!l+GG+@ ziIHO|Yn2h80v1?WsH#MKI3tn%Jeb)b2ii$~^GCI_$1W=sYSHd$ojYMAX9oFmfUr?R9Al94bBVK!DLcr6oVwvZ4F|{*M zXJDu8+E{ZQsaIv<<^dvesq!1h-elw*K;o!re9fkS;Ad=8E74M0y!R=0z&v+Bzy$4cEhlFAC@jEu+yvY;5BJmmecSG+@W)_a?z zOs=6)NQDGj!(Z1iU|j@ZYMu`kCb(_I1SKk-F81fLZ;X~=~)Uqxg>SK!c%>m z$01p^@{b?&z)wYBKp}Y}0;T|nu%JKzplNw@Fpd)Hgjc8tP>`?(;)DAWmBA|RbE!1* z3xQyk79IfIaux!G)FN+!UlSx_4O0hX{cfZZpJQv%?!CHII`~@j{a$TOz>DlCd9R59 z31%JDL>(Mrz0vN(M2h1i$o2q1>#|H~csDPoa43#O)ozH=ph8j=uJ30P`4f!`tXUo@ z&Onx3F7}X*7p6jdaiYN~05Ap_!zZ9aZut5NY?9ms z9^lb|1Bcq9mM-60VZ)Gz+aUWA$u3~EkWE(lzx~q%Yk>SAG}MhCsDNO>6Z6&*1kgz> z@N27*NohfA-oQhRNK1p+P;rU3bi1y(?KD7YdME}dDA|G)FVG;!b1;~>5yj#McMVbF zC1(Rv$T9pOoOjDE&gLi}8)}qdD5zeH2y{rLM$s)Qej5~bg|!(B%eexe)I$;w6R&v5 zZ4i<-+l>qp&av#ha?N11Czu8B+v$!{1NFTTg*!_RVU&O=m~dxQ>k+C9SOf=D5nf17 zSQfbncZ$Jy+o0j5Y=d4^@iIWIUXR2#sk6X0Gp>=ph-h= zA`1s7!gI`(txJ?}`xVe^STr&&QjulWH$_1Lv)MyMhQiW)60^L()(s1=!l(eG)GURy;$nCO zG!$|U@EXfig_Iae2~h~xeCZ9CiP)vo$V%3QG8*Dx?q_jv7O6lURa7D$2H`X(2b-u< zktjt5tx}9gGT?MY#4WtWq8yonE5f~nEI}wD!a~PCv91t)Dh^V&0b>h=_EpXzo2wiS zWPsAxfIp>cB!2>6*T^vb1c(+MhP0I{J1nGo{8?dfXf~rz(JCYn<_z^z0^A2GChUhH z(v3hv)d&&Ty+A|gcL@V<0EC>cDRTpn-&q!J<2gtl@HJw$nqTDD5(1&%@2cppwJkRL zh($1&n&{I3Sc=!lNI^KB%w|Ca&Qs7 
zh%RWPBaEt*08Zl?y#J9ENh~rHhK2zYd04atULoKB$s*Y5BGXtm5E8?OOA*^d1KBtc|DfgYh6OX&7o)K11;fcq);a)bFeeEPXR+{8 z;?gPE@llvy!WDtbNZA9}NUPNZ?}|4y1Rx=McpAxBe;R9(fIU%EVs)lNf?1(b8wo=V zcCLh<1XFM;$OT?~9v?!w2$>2yVYeIQ02QK(s|U4Mf#|C8zH%#=VVnh$s}s2jk+5KR zS}C`3Ka^MJMgK}R63DkqfwdFjmWHXh$9zY>L05^Qlj6kkkUzX*0~YvkAOxg95tM>s z%<74Nmlu7IxGoB{Vj&Q`MqPk z8%KoluTwS~*N7V^VT|QFa7qeVRiD=bg6Kgw#a|5I@P5KJmViM3;Z&rpDwr3Z@9Z=~J;suzcdUQjmy#K947R&f9g z8{g8AaTPqu4`Wa!PzE?;g=5 z22%h5co9rRO8`Sv@DWhxUhI=wYB>Sw3s04yEPz&e?fJ4SQy-$Ah8jQ@@aZDmu`%Ri zt;k8x^@;YpS!{?U2T6tYYzSh}lfVIb3VjtglEg@004|6py&ho?5}itF;6~JP0N0x1 zeMvjv;5ai;nhna4o7B^Q%LPhffhv-*Dc}&TpDZgf@bHE~29s!G*PotDJSR++WrmT4 z=tW|QNW)WdJJ6~+1|YVzA}AY34&(|j0q&$is!_D?>T9?nIJr9764oXE?jolForx(& zieUq|0s=(t#sqGc+}eOUss?-593>r)bmY)DUILh;!O$T7u%kSd>KS81?Q?l0>>aX0 zM${WRrOMe2)VqV_Fw}F9fB5|R+V%}*k+Gm|{-@%vy4Wfn54;A%vkV@BM?NIN*)So@ zMIiCSiBdq(Uj$GEe*nfcNyf-BL51jS$#=X-E(h4&EJguLrg%U+I}l!=Z$?B@R7vV2 z$%GO`#6oPRsh_!4=&q`lz%#3DKt%yY1snoFS;@dlP#vLCdOniGEMc+lUgaYY=# zAem=Jtp|ZbK`x+OM6-8Nkp>l!`qXpce&bBTsNVpJU~QqD3S1!n;K`_7N_A6zkqgOyPH77I(A*aLA4-;Li zT_iNY;Mu@o4+(>;u&=Ug2#RY^OAGQ`n3~ouLOAty0;}zd7znvB)W5;C!ALWqGms3H z02EL|fDK(v&;;#^3PK6GZ74t(lu`oYOOac>vkSGV`iqD^NkLFHsjCQdm^6Xpi~2Nb zNyUkMa6!DWR_C4eoXQUHLLFNA(ihihCkUa9wJdQp8iJ8kHwF%II>6MXw6* zCUQ~lqJAQZ4Y@=lrLLiJRANMGlmZJ^?HJCfB@P(rPN+gg`eg)K)Pk{o=kVZY$ayh3 z3oSBtVY$&$aEv6L#KY>?@DI1~)d$Ag9IA=@!B)c#yovTtknIUv?@gV!h#w(6_f#{B zj8XDbvD3JyWHOTKS{BYqU#<(Uv3{|GP5eqVX%{nYwWp~- z$_Ob~x0ClA8{Cp`H9U95>Vnkd5u`xPAqr)wxE;WsG5}Zk43&t~9fsEh5 zdLG;)S`knYCUXyJy9y8iCay!4P<2n)ViHj{UM8W$C|2^n&Bd)ibZVPU zs3W487IERB2kJ~*DI!>si;zo_C})Kb5nX+d9HmAT7!^a;Q-QcFlC{yOO6>;V?$lI- zy+Lyoo|nbDW9v=^f*sEnw@3$%L2|+AB+r#*CDD<8p`eOP%}|jfPf0RK=21IAu)lB4 zy&p%q1?+qS?8sE!UFk}epu>4~FfTzJ6bDKNrtxRW$lpPQ#EOba@ci}EHv~zw0a@tD zA_E6eq41X(<}kutY5^^j1}ug^GqNFY9rCwTgb(b*Doh6m}$QM9E9!gT}-1lF=l&h7b%{eKuuH`iFo7b_5V*ft=U~l@jl3 zfADa~=sQxwsi7bWY%-T!k+10)F$WkjghdyCwSK(|+76_2NW9I>3kmrgxppTOkUZ{1 
z2wQQ>P5W+_i(IYwk+(rRQyvAJc1>)B+{G0jIwFa1Sp1-@T6&E&oOPQq)}B{1>h|YLF0W#46nw(jP&gK3qQ?r@Ktfh({Pq(K@zTwMVyyxxpGH%{k4r@3bG`QT zP8@1O4>jCzfpw&uNlkz~fV>wh6Ja1fp^~?|Wxy!n5ge~x?H+Czgo9^MY~asqB$n6y zU>n-b#3Hpfmk^MUff`2nSGBQ2&-sNUey*sYDjlOG<3`_^P>rR?sQ z>iZy9oSs%o0_Ot=+LKdBX;#eAx$7*JZVwu46qAP#g-*KfE@>#cnC+S{xkwSxv3C5Bi3yuf&^~yt5ULR z+BZ-;EDxne9vsLRuRpe3$W$VxKVu7!M@zj(?EqMg=v&|saygN#)Q8|mdm@_Y7$0$r z1{`{HN03dL=W3g7XrPJvLUSVyG2R$R;lN-14nm!+UH=~fTYQ1=a?}SBiG)D;6fDvX z+E<18X3!o0&S=O;CrPXUfQs}^k5@WOjm18-nJ|Y7*xeSu3)Z*8&)RrHD`{g{Tf+@N zh4w;&aZu;HIcSF-uT|7( zs$wLPzGd>Cd;)0??&u8U4h~OUwP^EgNWC3{A#evU2%LJdRJ;&tsT0 z?mlj9wP2v@DIu?o$MxBVz=@IaO>NR%H7PM-#8u{lghd5m*(=X=eHhRFIrI*ct|mM0 zOFeXzQj|N_3Lg_?MvVAqt+SMNt4B_a$SB z((Xrmg6KrRq#|R~B$_f34`N7%JGjG<%i7=GYGv{u(2q6{1q>K?7!>nXkc3=Q_yF4E zgSwuq3+_oBv>HS9<|Fd67mpC?S24L4u7=C&wywsYrJh#2Xn);NN!9e^j9?q#%qFzln2@B z_#%F_NrSI3i5O~RKJiO@>_DcX^e)E1LZFLf4EXMrFr>oIxRfFEf<~%T70R!rO2!7> zF&;HV3K|;`q7Chw)c$b;j-$$q>Z0JjBmyy-Mby-??!w zKalvE^GzK`7E)#gRiKeX*?C?7W;RkrvoV@i96|`v7n!w{23D#RSVYz#BpePDN47<$ z5NCjp2O@UJRLR0(>ZS%vw))dNZJz3tDO7;2RA9@598`2SA))z^QchMEzFHroChC{f z*(`VmW}q(abD5ghV4da>qhK^>Mu(95U0hXn5E-H>fc*FX+LE@$apBESmb2sSBR{9j z@Z0f{&=nbnoI?FK1R@5vQfm_qU0fR?p}-;Q$6^}y(WC&nq&#sq>pE{TLI^#I9O-Vl zYY;kC@mj(DlcGptSRdgwSV(ZMP!QNn-eZGGR^?WVlAg#!5}^-uakxm}VPXyrc@niH z5&E0d(oQFEa}{1922>Ks(Pg>RA&?xEINX|3QV-S1ys%E|;2{tuV zk+AFNW3tN&QMok?6*R@-;4`(8k<~m}?RIr*3V%`?u&qZrrh=Mo>aO>u_PFozXb37* zKfHv^I7b6T&=78_rsR3h%-lQZn$QS(@62b@Rz)yL#S?#EC ztvMUqogJ^zk&<=zj0WYw^A+Dx8bumlR7?{00#HY7cOlq8*XKddbD|bP%Q#;|25d8Ol`pKzgup}P9- zZ%VPE(S5->ATqHw+KuN!zGMuEqe;{O&i=aIOF^>Z0BQ{;q2Yj;$|k@!0a zc)$yTQ}C7=^v}|0Q9v#vo={S0+d&3dwh^1(4iVP zWC?i(f}#~_2$MhZbX5J8G~K8)(&x>Hf?Tl1Qlse{wNfC}cpPCG=+!ZLg9gaJbZK^{7BVCDtCh1YGq3LaHGWlBUFFCK8_XLV$S5 z2QEYQ7HvdmzY|%F0qhqZUT=y<5?bPqXL7A+lUs0xfY;;)0Ivoa-lrZMSV7VP)xeKN zZxB|Vy1v+iI@D`k+Q#v%h}XuPh@ztaqD18wWRM^Oj7UKT`Kh_UFr}n20NPyF zJdywG`bW+3LHRyMs!lpWpSSy$km9r3ez$HnGH+e&^JQ`Hkv=lo_m;m#z z(e+@MX#j~g+}h7FqQ?VfZ&cbTN@Ga%HApJNEExH$-{yx8YaNvjSQ)-lFt{^#UR2JK 
z8rIHcx|r?-tKBWD=QW5ZkKqES8+C?3&AEkY>M=r^MrI{ZcuT#WnT%1VO^`Dxs3)U+j2QY8GPCyk$El>*weY0Dq4 z04hF=Z;E1WnnHs>sOK&7q8vqiY!%nke6%AiaO1Tit5~FhR+VD5AQxk1TFORvSnbOG@et`>oZ}`JaZE zRNI0|DFB}f3F!y`9@Y{`tG(Mbq~QVJCJkYp^#`qC^5fOofHi;|uvdLM;iP7Hu0vuX zSM9Znj6hIi)DxkMF^eJ5;#J+wtmi74XT3}73PT-@W9+Ow-f|sRUH!GMbUm3g4JZIQ z%HHs(!9;I6w(^J)?N7Q1JHbMzmuv;Tt`m?_j1iHpxjjtDq$*(zM6ae^L3>RO^5baM zSzvaG#hd6!^Sb!UH*{nUMlqZyM96A)*4tAf=+T0SIGTST(7+{S$9hl-+C2pP1|-)O zh7?HBG(DINpA@T=d+d}04;`Rw%|Y)sm7xM_8d~QL@M>fzE)Dj|#UrZHdHwvGyyJm@ zKd3iOWWNkgoJZ3<;1;#14fRdI;}nywyH5)E!BUfb@kYE~yMin%? zS@pna*0!7gQOF4e1ABYF8D90$G+#U+UDIZp!+Y7*oEH=lriKRs4=cDZY8pe{xi`30 zqe7aen}{3|D5?nx{a!Wp6iU=cE1JFMg{+j#tA3g`Mo2SmylQcg{XJnM2+N74`H8Jj zKHYjJB4oXFht7kCP#ogPG=>x+>n5wtlaQ%f`GeY;8yz*Ik*oH_k#-HvsvX2$_#mfE z%U)%C2K0yCm7(wNbOI5=ji4If1%6)w%79k7RlbmC3RnF)9#%72@T|W;@!M%!IxDvE z`CL0-0~Y_w&Zeb+CW}#;7;CSMk-9X`LUC&@Z#&XP1C^8u*b3|95sn7!lvC=YCQVyv zLyIb0sG}jnei4i|)MJVZe(=`Jt=xQX+PXF+O|w2;gKpEh{7))x)+QvlO{|2FbR@k_dy%i&DZJX0 zhb(}1Nfk|>%x=o!j)qYVxvGw)#Nl@|{bll->5(^^`?>g{N zq<}OXjW)tuXj5x88A^Nx>L16MJCuzruiQy0ADw~4!NX|0M_Xe(Uz6M+H8k}`4as4? 
z53+*Vl!dRFk_hf<%ksHav~+L&)7ykwwK=@_pDDRo|H zh}Wui?B>9+p~r;tOH*<=d+nzuotV)0iI@nYw63WxZRzouMJemqYl%+BShM1OkJ03A;uokIuSKjwI$h!My=^TTR1rNtajWA5ZXh(a~9PLd6neck?{taa7wxJM`*@;cg?Q zNz+g7(5ad_NPfbBu#6g9uy?vj!vX2-Drw*jPpz|_05c2J&S;dVX$}YFiMf%5I64z@ zb@VD~TCOPx5f)5tQu#m8!saPq%sOtz!x5{-W?;&AJ~7 zff`_)POAi@>CkvBE&?g2z4HAzvIDE5G!I4X2qPM0DU{s16d*N-{nd~Gd5A}XC=7K9 zDA`bB0c;!3`A2Y&V9MaH@!Zi=`wX&?s>>y)Z>%YI4bq_PNIe)*mR~n^@?e2=g3Awe zU(wYXL*Dqc<(o#;$f9v{yaQv(YQtH5Shi2E1dY+<`rg+3S+XK_G!?xjKn12Qq21H_@U=AUMD)~R1k_KgJj5>*R05_GKvEYtgSko zxWZw>Lp9=HPJC+Q7+l^6ANnY$*$c8`ZCL_8mv=!&ku=P+w4bSwiqUkWDr)pnBjyx= zCds=LBK#ram9&qfnf$ml;AjaS0#Mq5tUBOinwBUDX;>hcO9 z46@Zt&`?7~Rk;oTAzvVI?Nw+Rvi60ygj?GnGFRs^$yW^=Md-K)cS_g6khvnuE%~H2 zv=C;bsMV3u2V$DWy|PXNH6#z}8p`EeB>UBllmtK+kUM&EX{I+|cA>ggTcw}jT%ugSORT*YsMys8If3AS=yX6 zpcA}nD0akv@@tMw6YBn$se@{coR=j zRuJrqn)_;g2A9>ghdHuXM3AJz;8U9prrxC2!7boI0E|2UPo@m&&f3`7CwDQ7_7)v# zMjqWL`lWRS&6V2_M|o0t1(;cFO%2Cj+-KlrVs4WX2cbTvRu776P>KKdss%+RL!P9oo5~}AK@9s zL32jh3Q}*y+wn z1?phf#I`TY@hNHmq6Ld)1A#5Cp%r|=jp|XUolr!~Ck&l$)nS5s5;zNS-EzAK3E%Gl zV{P|f>snSr9;c5?ZWHwi{_);kai5wM=JH4*WEjb1H)?ynXxA79z(J-(K4?}j;nU@T zNzkEv^VeadJ{KH}@d8R;Cx}#38$MewJaqcZJxz0HH%d*|IJBXL#s`gbfZ1=gG}S;6bO2g3u&z$f z>yDB~lF%0#QqBjsS2NR^Hqwpx*D8<+i%^6PX-R2Cb!%BA(nLcDD3mq3Tfl%u^R##; zNFsLilPQF-DbcDk9UAjgXF^B6O{Z#fkOn4>IyQl2genr;U&$|>j+>gt(&@PI4vg)+ z(1JO5H9Wk=pfC<%Eu9-Vc(1N`;dcl{{XjQ$e6fZo)2s;`NP0bu4o}4^1j47lFT969 zB!Tq%4Kl%Jthk!$Gu{!LX+trk5j7>K9kOdvvzq&z7}vZ%CaEq#rOtL&QW4(3+VXH< z56VHKT|QzVAS`WVgD1eQC&7AF-%vWQWNi&=L_j+eQ@aB!bWvMFV#VHh3KMxPCzQ=e zib)gylz?`CVAb$8IS>MXcRLkx_Q-n-;V8KHLaAWL3isSx+iqpfxJ8 zv12!D`=Az3;+=~JZI|R~(GdupMJFU6jbhrm)6k4uA>hoXbHQ6f2e`>4su@kOb^byd z#QkcT+4D36a~#dgO`HaSQWzbyTajRZ7v-U0Uf@T=yjT{<=`%o)-uVZwdR!3ODl@zx zNc9BK^&P&WAx zqlP|eTtR1XxM;!4dY$o0$^h=?)t~GF1WlS3c&3UT6E2RM+OtwkuJ`gLE*e74p!zO! 
zMagSs0b!5Ke!KBo$955q+O1WWC+4`xV(u?za;A7n#1d#`I^LsWyPx=gAv#Qr6H(tr z=dij)^MFy)sJW0<)@@}HbeN{796G#VG=e6&yQG)I52+5=Ua8d{PD(Q5)aeDMmbQ1mI{_y~l zRA+6L&O>pVB#@yBZR2vZD1Ns$RgXG+=>l|UK2#f^2OR~PBCt-=nf;Csz#Zz}2~kr8 zzlP$8{{6cvARn4;HHVs8fnYbChE1SNNjU~ct+6-j5cVO&W{}(x5xo1&_8JcSZQR13$53gM^A)3y3C;Vhtwyd z{639dq8oVE@`mZW{&5YtX$_Q0Fz6p#;+9;v7nkn6`%jD60He2QcnJUi0fcEoLr_UW zLm+T+Z)Rz1WdHzpoPCi!NW)MRg-=_hDwPgq5OK)RI$01Eanvdlp+cw?T6Hja=^r#P zBq=VAf@{ISpT(+!i?gl{u7V)=1LET9r060g-j@_w#Q5OyKF)jRaNm7Euu)~2)inla zx@~4sNimmS5yP+Og+LWYRA!bjD@keiuCIFp_WanVn9G7o?(V*6R#6b zZQ2IsePV%?WtI4xc+{i@i66PHc>Kn>WFn8f*2*Ul=Lq%S&9R8b%5WScC)-3TmjJ3L6R9by6&3=sfA;A9DQ?xfF6$ zz{oL=4QP;EKlmT~o~>1!826GQNuc}1aX#Weco%5a9q0SlahfMU=oz@u+y0FPF#Sn- zy{*NLfao@GaoyIGJ>YT&7<@8hQ+A~gRj^nD-p}Zp^1#3?(6{Q%t#yvm2OvwcO5Xqn zhrmdQve$gx-P1X@fA6&B_X9|{a+a#sDFpxk00v@9M??Vs0RI60puMM)00009a7bBm z001r{001r{0eGc9b^rhX2XskIMF-;v6bU6W)}8B1001BWNklpP4-$N}LDYXw1~CEYX8;I-2&64TtrGk$h|0kG{pVJXqP3sUW$?Af#@HAm_BUy$ zD2OVe3b>$d5_Lh{enM1{1QLRjA!s43ST=;>B!ILKWDY5@66&ps`u_>4;`4~WKXYJe zY#)t{KZ6nG&zsr}jVeYxA1q;QPsOeVt;PnmMLfyu(yznEJzd;Tws(THIZV(9a`hAh z9Z&$wbwwFd;>X4qvHuZtc>br<9>JbcuraUtD>~}-|wm^-VKLib0 z{oa*Q-l+pd{1~N=jWJ?>lL_5W=^)cQE)YkwTESLO5eR{Xp(zYoMDz$8*#gcn7yb?!yOqz*&S)A1RBEw^qeh&o zrbI|0#-w+LMD&N6vxumwimzzoS|)+d)I1#j82|U+XCqx}sB?}qzoE`E+#K>9JJP)J ztxNEAj6}zX-4MHookEgT9>LW$BCA)`K+V^RXm!T3=b2fAk4|9M{?o{0ONm+*1(8CE zA{K>FhfK!8akFG3i-!A{IB`3rQ?^if>TB?==Gyc_*#vQj*BT{WRBePRgm^!(TSm2T z^AUV5QICSyELC;$94o%PFcL8gMdbPBV&0RJu|;yvo$v1<7~P=j4mHOyGW`oDb^#%k z3X#Mm1zc4bV&;Mv745}5y0#iJi;1|^ghJy8Ya8VyjU1r~M>3QY^-vV;rL7|Z*+R+- zN}y{VZnhFjJ8=76qAdCpGq&xA`!nzY)ODz?u5moG-q;OEbv+XS@)3Jr<1PYX#spCF zGnBio9oVm`-J>C((MZq$VHE=_f%!lL&or@sPM{um3FsS+AamY(9RtewUJ0-am;n@+ z>oo!EfqnVTGZZr!Bh|ljVkd1u+@h&5#(@px$bm&1-hL6G-pO=!wIkLpPF?Uav9e}n zpVI{6d!WaGa=agac7g3Nj_fX_Xa6(~>^PU0OvQ>GX6h*q)AQ|3%-Gw9sHZLrkjXBevX8gpR6t{MeZs;YEc2Qdljy~X9PvK$< z)W(JQoFbSzmHqpd)7*Xzjvk?W@)~9>-%=}WJrF0M0BLLzniTbsd>P7S7ZD5i88FpI zw4)>Qv3~G$2{6+T!aoD`s+t*%{VgJ;zz2Z1A?_o<2E$1DfW78ts$oPiU;=QCA<}1n 
z{{hySOATiRM1a$Qj{|3#k6!~e8{*agA^QO5bo1RXuno8!cv@9^b4`meV*d*!b|GcK z(Fh&U5*|5vA%)gk$mlf6A{)p?pJv&DmzmJq1Gv=Hv6eb&YS1*0pw^rv?t);akg6&q z4oJL4;N}Qal=0frHMAXCNw{D?Q>*S_Pw#FXIN*>>4lJDn)=n8t%HCr(brNV8fk=bI zewr}>-fuSke*#Yc&M;El0u})81}-yi-wC{I)byTQTRrf1Ij|g91zcpL{SSdPs@ge( z-5mk`3b-7Y4D1AcVx;w}cpi*=8o^Y<%q|3KfDYhE;J;M$@bQ`~W8=@2*av&I1r_iA z-dsblF~$pbEug!375h3*2RXvH=u@0^Wd-*!Js@)67QBX96SVm7!h9}^A3{fOi^_AmV3Xu@77|*>Q zF|W@L1GFs$J`2n>()k`^2K0@VXpB}+4cr72n)B_-1NfBz9|f*6Cc%FLb-=+ zmjb5(pEerCEoQz1nb8=jp8N>^NfLU$O^$`8L2zwsq5k#|%idPPZ-29dy$9YwNT(Bx zZR9Nre#Ol8ouod~4T^D%0l3KlJGtLQ<`KMl8wnEuuO1695ph}3wM|5(uBYed3|fvZ zV|vRDB3EpNhcf;)OlrSW9}Sv)LK{|99T9mM_$Kg=z>S93`>z?0wlUztMjLp*Xaenq zsq1lEvt2}<1@;?hy(<^+6OjVogTVWYRQ@KgCa5V4XKx1>M?3H$@HwMR>^s>t8yoq@ zpIB^sP|XwIb_)n$+hVUxT1;C|#$(?-o1L$`he&uFQ>H!2oTc}0?B%Vb{=FL;#qQrA zCZ1jKfHYghljdS1y5$k68X0eU2`=ip%-8i%S-Jx!T2Fyh&W_hkhrcZ_0lMnz94Q7e z?cT(p1}$$e>hIaP63-Cd2h3;B0RIKFodnScut!y&AL8?=z*`N__8s8iT;?-^WVdqe z?|{KJj*hACe*wfUMzc<-NF1yQb_^MRw_iXtAit^->8z_&HP0dt&Zz6=q>GXjGd>bPRyuYpOx zy}(1pHXjuHGr~1B7nJmmHLo&;#Q&d`>;MUC6AMT~G(Whcb1>D!7{XYUJNvR07fGr-4OF;b|h>*85jfD-iBwegk z9Tdd2bEs)DyZ6q7vPcR@E0&P5EFdu#*Z27gE>(AdUf7k3fUeft5}oDU%@MRk#pia!aVrBO;4N zZ^I=4`U*1?O>LyNM=7x;KuHWwyb4qae%||=t;c)BRy=S&{WPq=L_9P+1N=PrSZfa8-LaSoyaP%n$oBwaHx$(SoV{tJn_tx zbhcegU-ke~r{Brd*KZ+J)1!7;Acc_htZY%F7U8yJNS4*Z;W7)hhoLD12OPMXI&KKp z;e;S8Fxi!+#tgl+4n23c%sWR()w#A^=O%+{c}1_oCaaCG<qg}neN5?U2U>}uPH zlNt+>5}UI-MpTKhqCVm zewXu|w;PFnA0r3toa~6;x#N2A0%pd>7_py-a97VUzHD0|kEY+kmfhE&ZU@sVf5a8n zH^9OS>9k9-5j$yzp*RBJ2nUXoagYj%e=~#42TvywE1@V>L@ZQBBoak5i}t!?dL2|9 z#Dy>s2wOz0J`QX@gPnU?DH#78-Ai7Uf&)!Z+2`5EX%{YQqkUr&*|t+D(g>)e5KX3l zWKtpJ>VP9|D9L_8h`6zXrvO*s#hBOT5VROu-NFmDao=R@V6M5=Pk@#&5N(XuPiUk7 zZaP%Ro!>Z%wu2ua67JxX${(`i+Ksj>%h(E3BrKR0hWD57)R9HBcFv}_cOgB!RfJt- zT*x8}SlL!0;mveet@MTaurl4mGMyC0GW0kWeH0O)kdlI8=slJ3;e}Xjb7-!wCfo2B zGna3LD>?xe_I7hD(@mc%Wcxxu*hU@O7FTWKzSkrp*rzf|Vj6_0h{)N%JMnzbPXIs4 zqc!-+EM0H%BMJeRnrrUzLtDoPevH_UPil~|IQWi2e*4q2*thS)SWYKrSKY}ar@cb@ 
zSk|tovp5ndWcrUPc&vU2J=@MebS7DA0-a&Owc6=)o0vIe0}^W{N(T$84iKAg48ksS zI?$Ab$>K$Y#w;j_z|I)^UYbTn`%)s#S%|ZkY>P$H)9q9g?(hm=+~U&h_R*8gAkF|< zto<}RoaFXDi{XW_hJBNKm}fE#&y`k&m!USn5cxy6-NK7R`-Ep249ghAJeMK{koP-)pQ&m~p8Ry5*gMS+e*}W>PfHz>>i>s{ zeV7cdd_@p6&jzs&97vfGQgxDuS6i^UnDtN8pqZHzls?7sw{L_=aPY}W);_)(tLGe? z>|Dl$1;=S28(m9p;YO}sxe+?Lpd9L3dimy~4t7$Nd5DcrbE{N=#^auykhGI3{ncZ| zC84BAWD^=c%I;myVHFgS>5J2Q>^zRiX1>8zs-`N_UJGFnArkQpwfk#Sj%K4PUO>z` zybQLH?0a8jo`;u|axJi1MCyQ!9ChD6*EnF@ACdS*6&5>G0gmHfDq{%!{}!>Uc?7|7 zOeV)u@oTCJ1N9rdkW>vwb;tX)N)X#iJ3TmaJa<2NB`xi5qo8mL%Vs^uk#%LP-+c*f z-RF^YW*|89IlGxL{s|&eUS+{KyEtWBE2;V(lD5JNVU{m*6*OgdlAQ#bIE|p`G}4%a zWU@cq(%Xyb^R9Y}?YDNY^s3FY9NvROE7*JF3_kXbojg0V|cnv4g^0 zN0jcZz^@s7)|=o(sMX>{t9=HTZKV9MfMj>g`DYpcZ_As!XF7bwew>VD1@N&PVs{Cq z8+&7|xo15|aC2ys-PNltgA?SrJ0=DrF`0oVDaxN{jbVmhGZ(iz0T$Qx|L@xeQ1tG+0SH zt?33VQm9Q@?W8L46wn~HZ$0$OL|dTVf@Q+3buP6ZYG!})9xQhj>P(?vLYTS7va~ub zSqIcr=r;#E@h-zxSut_o!f^!O7!I}_%fb<`!(R9U7)lB~CTps~scsKLuh}PjdS`mQ3dq%MiK@B$J>Xv73ep8nE*K z?37Sh8KSGJoE@*dm4fgjW=?3}#f=xz+p(0=vSJFYgT%@nrRt12PP={&I9X^>Xegqp z%%Ww3BW|yYwaJ3Ca`3gRJf6-_T@A^Z0w2jLamP**B3=c4vde&-tLi1edR1d@t`P{i z^mO+UQ&6)$+Y5!#YzWr{x7{I?a@8{j#Mcgn66`sC_EB$K!yx?c?*zULd=W35^fA2P z8h^ANFEcqoz1T zD60LM{eb%3pC7KOw%@iKs*fKZEBt*FIv>#FXk?oPT=m9G}=e`xYr2xM%ljcN~}f~M;ugov0R9%MUhy+LSJmw-x`Ue}?wTVaL;2MY;@ zi@^$^*?t`*@G_(dACS`4Xf|(dx#qy=Gaz3vFeis8>^Gm!!SgG=nZ$ksFQ5Jl;C$d0 zBViDEie=_Ym>H-!Fc2o}j-mbkTWY*f-+lEJH8J5{pA->nC0%O~2b#QE>jD5H`H001Js)6!%S}uX{dGnL@U01>Hv%(zN?!%K9F*XT|ryDpx?Yzk_S7MY2&y zrd1^2#?z4AyM7#QUnRXx50x{vLo~#$1Ld@JO(mnsvF@w}ybCI#sga9U=<(W~c{aEh zaLmmkIS=EdwE-77=*GSSc)N)F2Z3@>L83bjf*h)5{Sn}41GudQ{zpW5hq_L_`1)aD z|A#~J$HxB#Ko>yY0}msqN)k!ndfPf}QI!@d;zC2C#S@QCWcQI-WZlUe>^`5!js=tz zRuT<`fDZaHduY$J(d)D_rEn87W;POyb>oJVqMjb67IaZDqt64JLk^tcLbm~Vx-BRU z!@ej@TMN-+g&aLpMburmJ$EMKh%%@8a zGj;ZB&;~npma(&W8s)L0WVDw7Ku*Pk&?Wyk&#T-#v)ts9VAN23ns z#NQ7|Y&oC*))0C%@Lu3<5y|A%yB{gpCC@blaL0)JPl4D2b>R%b`92x0SlJ zA>#1}bQJNcug&G@O-rycbLs3_MnSlg(t;>rg>ZBSv4VQCJqM{-{35ef9D(s&aG;C2 
z`aaH2I)K=t&2;oOGhs&-ZHdv+Z82lwOH@wn445l0 zd`$csijUzXp)>=V@(|0Dh6u02D_!+v1|sT))I3I#szYSZHHpY0c!@dJ08atia?)H! ztAYpT<^)jsvGJ!M*@F@uzm`;OmDttQAr7BjMAIwdx#PZ5S-Lhq7r$InvZgUqm2kUVXp z{5y_PE2$e3>eY*^Rl8rB)nKb?Qh=nWVoMrZ(_WwqzA_Ov0X?R3vjtnjY-*iMM^^=z zzQfR)fu>-ccW=azIu8tvu%~%B>Q14vXD{Q+b}*&155||!{o)b|V^R9Zz%ua}T}tf% zvd9}}aPq@!_7#Sp3JmhzlPehwv2Ha4o-)V$6L70}Ka-F6f@*vKDKEwIrQc$-gRfyK z#s-M5+sx%u;5CoUe2Ph_eUp4>ug-@~L)h}`6@Q3KlPQ}q3RGgCLjH-tzE@`GK3&(o9 z$k5Euu49x=c!)F3Y-ah5S;Yq8n)Hc9UypJ7$RMVpdK_3AM`~*aRok&VZR5ic4o;cD zkt56L&4h@AHo^X5(C7hkDf;Ulr9oG7ONife%wliz3NmgCEA%|es`rt8%%vkN9ByC4 z_`)_O7d1ht16ZOdv9Xa6=sS1Qj&>89DAaMlm3Wm`XBstqeb7Cns%{VjQ`MJ%FBuJB z1zvNBn}GTGuANiSCW4nr`x!is5nVw{ zPDQMP6cjwm!i8&?cFGawu?A0jCkMt^#XaBM9Q!sMvN&NmdoGW5!Hg-gO<}@Jz~LjhwpV1)Ezs zZ9%b;(d<+ z@0vwqKVB5vc07N1A@CMsA`}~@0-ihX`*_uFT8AOpm(AQR0}^=YxEC7KeuIeoo`D1% z7bp~w3-IdLjW^e8F=oLJ@%r3pAQ)+Eyir=B>OfW1An>G!BtX){u~TVH+BRYX$z+7j z{B;%GZL2tTF#bCSKLh^+*I4Fq(!UT^mpbvH8#H=G|@yUP&d0Z#$x?}aw19%$ye ze!A>C@#5-z29jkUrrrwtY{RRB^BnMd1_!$62;f>6FSGBxc>eySc=hW1YXu_-ee>}} z!^rLiwwm>kTtEDnbbq64_Q6Et$BiqBrilyOEOtXJ_F9CQ+l%>D@>~w>x*2smRjqW~ z8pf5^@wL>`u)B*R6T;l`fs5&E{umK!0;4d`#d7yZCrTidI)<}XJ`c5 z7LrLfu!906VgtYHK=2Nth@J9|)jOkf6-?mq`pcnrC2n{UgkEL(%wKWgg$?jbkF8g$ z*U}A1RNG2bQmCj9!h0rg=*T66tx^iC-*fuHwXnUL^fZU$-V*kA%%@QMSX{B0ukLpS zW~2e;Xpba<{@V~)zHGI8#GW;qe+zjc??y5O0!VsBM0Nw;GS@!~uQ+A!`VwG34S2El zPCjiv&$WZ@*R0V7{sS-0eHmURUXURKW8(XOKj6jv_n#+-;*Z(wZ#1#%Ajui1#-LDF zC)Cuyw}^4?zg6;$HJ_ri=sbi{Vg(zx;G7>2owS*Wz0GjY<+;{E?*7hOX>0y#!jV#n ziXZ0unxAm)MF;u%7d{Bi1WG3UnhP#J4CRg{s$KJh^JY*DsP2PHK6}le5J*P&)WtJ)1~Om@?^iRLpvmvsUaNeQSnz+VUA#BPexiEoyd#dFl@nsD0r&!kHy> z%K-|euEQ-lWXtW2ZQIbeaXf`grY&=X3s!95ckflWEdWxAWtvG@eqFjDknZ0O*M9kU zTw+5=)m%a!gouj>s@i5!Wb+>DgL46?I;aoaFn!)NkBhJwcyyS)?D^&pj5ZvTNa{PV=D^Kf-9g~j)<;w?XC&!-Pz;}TB@S6mn4-diuG z?dabS4OI}0JVgA1xAK;^9pRf_noC>jIfSH%i!QB$^3I@!PRgg9{`SWzI*?I^43hXE zlb{meCtsOG=l1i-9=V(-<%A>Kk@3Ih@=yPk`gyI0WYM?D=k-pxC1c!|W;Y(pF&iSUD$7txhnN;LWyIBoVSqx^V#*izpWK+@F_ 
z{Me)EuO9r~Fdb86I8z|d7$UDS;1hp7rjk<`c1RRGk16Qp+=H7Bn9KQEWY@~Ic$99?-A!*uy#>-%2)W0+nyzz=+xi;Ef5!p+abR8 zjcQttd6Z9>>OLZT9{guv#GQ z{WTejumAHL>Yw}v#92Xc(H~f__$rEdsQ>^V07*naR90^K)bk{y?}g7{-Tpbau{JLJ z&-16SYtQ@Xa~BYay~3i?f5hrbcS57f%iGJ?y7}FN!bgdfZD7faZa{5GNfO(Gx9k*n zK{UoD7o)%1zzrcgc|(L_+l%?xv!`uEZj)-Y4=l7xrr&CILd#L z@#@KgViQj{pptY_wk>v5Tu2@&pzgkTyzt`nM75GWHZyJUBQT?bs;aC_O6aXj(dc^`^WFGx4X+;z0OUN1X5O-Doqe~`>nQg+M=RDn9&`gcyfr& z?Zw>r%NokUr_tGQ9*&lfbsWmZt))Po<5Tah=fKQkRCe?ss~xZC%xz(oT_7Z?98E(V zDOn9J`i?>#du#z4p1YB7b}`xP7EWFCU!3=;jZ{`%dK-Za+|{zP>pojmE~Gx?~S{vP8Wk7D;1KcX$@vJr^>rpoyuo zzQvNSZJ|msHi(@-h{qLb3i;8uKZu*T28VXaCjEfR-?=U>>ArOBV{qw;D4RB}Cgcd= z*j|{}OU+(v8*E#&-~D6jK)B?^XBP7KlM^`FHia!eyMWfNrAT%bO>P0o5kl4z%%AfZ z=gfVT#pkrb^bGrtI&4`-tzuQwA*-7nb6IwSMJz+ zuRix6OP6-lQfmouc;usbr!it58-FH}eIVA|1L?Z!mF}JrGOmjoJ;vY<5I5QsR)(>p*2Zls4v<#Y^B19{L{?WIb$l+%5IzIP*~xl7@bbsmzVk!;%a`|1G4(}Rru(QxfkK{q;##zK1?p~N z_JSW!az#5S9Fl7-;)L0`vjC)$NMRB4tLypFm*z4q7D9R^u%)4fzMg6dqXoEq6X?xE z>9Lv-4iL4rv3St~Wc3K|y?!6lXZ668UKJd;!9&7oRsFbGY>{*oNCIUb7uJN`+66q^ zG?y3Z{{~AI~1eMd>=6qLvdjs)ClLlt_+mm-=AcAaeHmVEy5(M04O?|Fy2$K$PRlW)E!I=`XL zc_IMg@!s2L=Q*17IiYjoL*L&2EH9VK#!h{98teY>c`U93i4coD&bY$IaU6&82@zVF zK0={2gGGzJM*L$Bvt=VZv96q^gEPoxCQwvdg4=gFlC7qP^&DxQO+hpa3SkKo7Nsz@ zn`rD6I=kAKIsbY7_Omb2>2~3|-Q0Coe?dh}CIf(20QOExh(}@LL z$!E`G*KhGaTz_M_rd_ zxSPUpPcdO)Jsq7r6c_K{ogZ$X;rt$!;RRTx(=LsTiftSo_v3XE+5r5Gs`j0zh7h*E&c%{>I8G;hnU~3CGN9wJEX8rU!QD#@}Hnt?CKjqa6 zk4<@$=GTx*$8AkxD>h!OVj>K>oQoxgY=+H*PL3R!mpi>iOfdoFqf9_xXtgt|t< zIFhc{Zn<8-64w3xa$I*JPWWEVt=Ucb`3#9U?i-Y*E_jI|7XTkJ(mZ0`4_0Wg7)Vk9 zuS4rp;9R`A@!v9p-*XZhX#!pjddx_U#dvY;?i)c!hG0v851RW}Xr$9KqpU>hj~RHG zdj&=VS_S;>4Tzh!@Ur|$@#6QF;5Dkuj-=THt{VkDZq8K-oGKzmd1C@De64NpZ)UhX>P706nT<0kNkfO?p@|L z0toYgcim3!u~U%ZwY)R=BA+G0=df6YYfDO#Bn8q43A{RPwlOHw+h`(UsNU=D3M9>> z`k9kg2utkRDj|NLfFIwzn9W<>MW*u-pq+`Ozh=>lhk5txU9dVs3P@H7>ljGtP!;G4 z^Z&lEh_<$~iAE2xe9_~SZSK@W-1VABqyjrKIZy?4l-*g)hFFM~>%Ba%?tVv?V8aAA z=YjjLtyo-UhTZM<0Za0 zqfkM8GvG8cPuCFCDR`Z2Pd27uK3>kf2DFFAMWkgUMi;ox6?mD1{4 
zMf$t%qqpTVgmgY)H4-iV4)6WM8m6Av&a!2$jcU@iQBq=Ki=-1FA!SipYe6!^eLtN~ zSNl7OMt3sr%;#WAmZS;a?XRAbbc6LN2H^c%mCMM_QMg`wf;on_!q-QNCmIWW+7Os{s&Pr2jI}6WE(A+;4aIyfFu-GQ%mv;uOhw+~Wk?qpO(wVpEa0s`>;xsdg6&0y z)sPmOw4dl<#TCm|LwIot*tP3jU==d8at)WhXP@zNOTU(zZcrw6g!sv~FTiysf@Ihk zZ1m`*q5GX+HD4Q`faDhS!LTH--GwGKj37kY?oz)0cQx$X@mm!^oblUQ!SS1ZK2o===$q;8(L^$wj8C!Q>2EtQZEZ@&)$`K7u< z2j|%5&2Ssqu)oMib{Xul8>FsdklZZ>)!oV99%_;YyxRagl1F_WZlgyGFcr>WI6m_@ z<782IGl%nx3Z6UO0aONwdk6$x@H*F)1vDOiURRzdA(>+)7~5YN+*ooUUiF;dy!;GA zEAU>S!5>O;2r&#)`P$5hf^+5?1%@$}<}vBc37AL64OIAoL~Fn`K4Z7eGnS+hLe$?F z0%F7>6G1~{vu#{c&{JD07Kj(E#yZz-Oe?eJ30=n)lXXIrm2GGC>k1+CcSn(a2_DE* z&zDt_0wnH2T&S-f$DQAp$0u(52VUCr4}_$ONaQspP5L^gU-n(Xv8~*i?y}_pM@5uG z!ZRIgfP|YS?NXC;;USCJizD3rFK?y0XC;NjkFnswdeU2ZY)J)~(4eM_!x6f#UYQ3i z#%1)^W@F?^a4{g;PJ=SvU|#OjurM%wD36$Ip50s%o12tEvQ*kQ5x(bBC z7V(OHx`Mijn{nU)o$Hl!R+|)dx=N}V0Gcggp$NUsVLtzvwRpMdCa}%xd1feDFn!Yf z`txF{QmbgrE=oZ|GQx)YCUMu#ucV{(TEaR9M~~pJh9xV1#C4bL)fJ@G7#lsj3s${64C?$++7-ZZ^_v<8QqculAefY_=gcnK}3S@M_x0 z5XAa+;C|zKKW;5KfS_NSbI-x=oN2zZcL+11%m7j40j3nRTPGN8>46g@_O}g#jl3PW z&#e2P29$vyr%MgS<>|GJ6`3N#YId1om0147A zT5B?0Z3FZpXi`Zbq*Ab|vV_eqy^XexQ<*jU$Gr7Jn@Qc(XTZbPOK{1Lx(93Jzw^Zlh{$B~IXTSHpJTBhs0TT|uDD`wiK%%!9sY^Oa#4M3xMS!vh8>Rcfuf$LefbO z^~Q8)Rj>S$7n!y;ZUGK`S(k7~=~H1riQY~P(`X#Hp|xt!-X1|BLPlL&*YRp4Cfq>d zRCyghC4+Ovj3yA}j$13(v2!jB>px1&T1dp2KsdUOk_itmbI~JQcl|b=e?G(FCkJgB z5d-1*>U7c`*mPA%i0_P${OS_gJ1$4Uo0z-schog>V<#0G85FTMv>ekB{KVxsXl+os zH5&=iHB1k*;Q2DZOGf4I91Z{);^F&R#;knWK!CA?jZHWKui?Y-ZqyL)(IM~&|GlWO z=^w)@cbywd^L~c+cvwc=KHcCJ#{*8DZHUPY#53fArRN(ayg#76`}260(XMitWNFKi`Q%15?vdEz2)5Wfp<{% z*cW;AaL*~+uP>_%b_w~CZ5=K(c#=I<^cvyVN;PvLXV741}CS+dM*GM49=L4 zM~YbtMDJ+;L^}tsE9}l;+-v#fakbfm3}LeeCF(ozV!e+B=5@soNOSHQxCWJf0XUj( zo$?gLoF4%Fa!r|BH|XhxF>DHaw$K1)!x5qXT55jUMyYZ5>s4bISuW`5SLJnbn9Q}t zoXPlg{f09g!UO5816`4l_-JxU)J`H*IK=5=*}1=>;iap=SY5 z_a)}d`3-XyZsz0Hw!u*cPI0vYF^wGwwq%jP=&c$7?LAlK(t~*lva+=$~O-OuaBEEL{kL~xQebE z0#FJPUvL@yCKY?gY5hFRy%%^6oPKRa^Qhm!>(4b(d~LvoexA_|B#-7HIqSF-7$+jl 
z!OwEEBmeyZ;}*-cuX9y!e~xo}o7micIq=~;jAtnAV}-#79y$)JW4bXRLU|b7E`v1` zy%*)J~wbO_u5fJ9+h$`|0ldDEsz)hQIpQG$vNu%Z?qd z+T4`|EX0ixvI?=Jr~evBToZ(50bvdvEaHX_UQ5BxrjY49i){7`BH<{psG#y1q1Zz# zTKphi{PH%a&%$zrPf=3`$r^+jKXzREqs~lz`9L8^2Hbci<+XrM{`3->4}FY+g4a0r z(r+>SmV=~lN#<4OPAI4(N6O{{nL5Z5=4y_+#zba~rUzdyUT^dVNBG{#fUPb$Zoh|( zE9vY!gn2dal&UuSQey;nQWaj5Tt18sfi;Jy_e@aZ08@i6A}2+;PV&jveczIBL<|If-YUy9Sq86htDpx`?j!7)mFqO+>>l zGHv=pXlN%l{opma>4lV2i!G0V^(tr$uv&18v<~K*`BAp6ouitWs<{z<`frP9ZHi+_ zJ9Fp!l$E!<#(W$F#P)kki65t{%1IYYlp*ZR;X1-BF+ZcNVA*cexS>i`7a1v2#_Pj# z+kiWaE9%TVCdWKrmWZ^EB;7?Mgct2LZ3y4{&y3n$n4>-~FxUS>08_|ixDT4|&Cda- zB1RJa6w|bF7{C1)0gWLj&>Ud1F+C;)&TSduJ{>O<{Bq+32x>!WB1NJEMXwwg&f~{qnYEc@Nz384>SKQUi$B_`z9Pn9qoV6UcGi7k@kJgg4W?k z4F|a4?yqt0z1!*Rd(Qt+}58O>(Hj5=nSRx38K#0zs z9txt3WH?Mo@p@L){GQ5*M_F=S8}sM&+228uh^A|?h{H5+rUA*-e&wv!1>sl(o2{kj zcP#)R9{R;}9)I$KWcub(Q1T0||IBmt95QwuR&9gnGmoBxboXsxD~~D~V)%ASQT1iGB2g7cK?mU#c+BZNNi6ILjZn1a-v5(!u?Ae`3zj$QMIGOTNIcj`}{Q2qQjf(eLPBR4= zK`R&*7|eaZrUUGyMxv1iPll6*xdQtZENw0Z`@~a zfLt~8BD}ctaE{uKVHA&|euY^hbc-n}Dkw=3H0gU@Vq7{$3H; zH)=Ir8Ps=OjFk60Qszc$(`gb3aR3Lk8(`%ozPF)-wL5>r#0le=KDmeu&%KYX&LynA zPM;(FF3~TzcZuC zdrd0iq)8ALZo7Rl|Nfhs>F&Ik^72})`_LVne@!zfuTTdh)u{Q#&iA@po`~I^XOZ&( zKYrzvw*cFMwwo_B^AO7k;RgorS(#@}V$|j?1Cn7j`yvLX+WHxCZz0fgJ)b6l7nv8! 
zgHv?lHKgprE5NZO;QsQjlMm2yjScz>ABZ%BS(0O0TgGOcZoowHxG&FRAfYDs9J0Wp z=KVk5RjgT%vsPt>IcpwfTx*z6lTqa>bHGl2{mOh)K36j`e_zDl(BR;6?*iJ4OU^gj z+&ufF2>2_&9!|1*K6Ul!ddSV%FZCK5QPdV}J4C`(yDs7Cy05m@{xZ9nX%)}&rBBq; z*0zs~)5o+)O?#ZDf4j}Z|UQ>$TgR{&6SbjeL zi4a&00qo5K&iN95a$u9^a}&K01rXlPK*(jTFMS2@Jp=sQ4g4*GJh1k>G)S3o3Vm{KDHWy6A!d8P@(38LhZ2UD4lvv&Ud%`}+AJ$UsC zuFLa$e4%|BFlRW)ZsmKxfs8cu1^yuQIh_>ZNsC$)yGB*gxQ%w&c(Zq8$+}sr$YdT} zdl4*^l4WRh9I36v5?efGbU3L9nzBja#pI>??e>6+Dusn= zgr{ee@$8zvCewR06Uv_AU03~>ho`q{qEVHUK_Em8gL4=N4fP0Klo2E7hL-I9arYA_ zA(f+=*8~VC_;BtuTH4rw^~YY%VkO=lM26IQ;=gf%Ewy9|=HUp8e;+qpGfwU!>w#=KRVywDx)Jgc#V4&%vk|fNp z?ylqqU;PX{T~|?9_&b&?`A<%}Y##}6Q?`vLmM9js?U^N#wp31DqUKZ?5#zGy7=qBI z;+3W<&0!WBR3V{zO#lE5AW1|)RP{llRwUM%DCZF}_U)W}VBQe#xjD>Vb2j?L zzLt=uM%u!_eFW|4F2ksT%wb0$P$l1n^~Wxw_6L(mcHo5!=Z=u^Cd6{g3(J@(u{?7V zy#Dho0r%H_L+qJ6IExivN{0h?IWL;Cf~Yfo)`fM zxcgS-1mz4Sb}v=ka1tFY`ip83R~5DG;Nq(4>ZJ0eJGTD+?VW3kRn-;8fBVdf0fsk) zih_V>fm&--3~HRAwn85o)Xw;5ZKGHtelZwRi%ojYtxZj1)U>u*N~$Rev2#JIAx=5U)plaqg#w6KkTl9!b&LZbhig|LQdGY~5 z885vwnZ~s@(SG1c#@DQ8>DTY2t?F$~KDnzHMMo_#PY6Q+q+VN3sgQ{!CHx~OJZDMz zR5g6sgVGu+@Zu6+l&@N-b1gCm!6$G&t-EanWHz!e9?cGqDrwv9U-j7RCo93?u{pby zLG62qfrmU}yK4MK;Gd2bu|v%Q_kMcRmcPK6+Xxk?zGRgawbAXYwt{*lKJsV~_WSD! 
zoY`}0SHb4qOX}5_3Y1x0a$_}rmg0}(^95t!XeY7f+%tfA#+Wi+*lb{J-eN&r{TV+2 zz@)B@f;LAFa+n|ti|cNZ-56-6mkFac&~@nZw0?9Jr%?%F=7Fgz)|NyFLoI7zOc)a6 zg+RNcnxukFC}Kjvgc6ygR(wOAAfG2_Y9eUR5Zs1)spN;Z&gEBk-%3Zv62?q=f~D82 zV&D0@C8`Y?8xn~sp~OO6uTWo*G_)}J7x%f}qhNF^j{OixqDHGn8pn$KE(7i{#+(jJ zG{)2zW2%5rTBi83inP{bM0}-89#|1$X#q)J7-ttrBFzC+JJa`0CD;y&U}KAjsuc(u z=R~JXvs^~5Kw1wR@Bn7Jt;jO**u=Zlr|i?3-qAT5XH9E>3)7ub#~hyl{)s-9f7sp= zecz=5#aCM}O-On+)i^Airl10(%XvoWaXB!~VZh5Mr5n5V)1%L+Q=dtPer~s(m&^3V z8f!HD(*;a9906JDmBSL_CT{I0gb3jddYN^~3+!q87M<~hJonh2kgN98C)suJmJ;F~ zr|X}?#9=8aF1>Z5&Tm7QX!r9QaUw4+|zyQgs?_No+> z_3a76uxQvLHTf}Fm;krJgtBsvKEL21W$@Cmy<-+YhvW~O>PHd&q?BT}jvKYW&_$UyxF zC(0wyPfvKCs*z0`XuFv8Pc?DNPf)*Qe%06Vx^ZChdE+;Cp+RHuZN%Z!3HET*IS}PGgmnLue8#JUx`i?0-d{ ze-6l-h=?dM#^iEA)Et&a&pbPppDkO!Z7c3&*RGoxS-q39&$x~AE_sl@Y-y34iSzY( zZ(&F?^Nvg7MSVyyuu6%{XFJE99VB*kBsW&tsQc$(N7b%#+-)|4H7mpSvE!{0TIpOj zQOcP&mP;*l+0C$Aa^vVe|G?2u!E!+N2dciJ+I8%51X*;{&gPVk)C_D68pK7 z;`H4lns(KmT;@qNiTxx6LOSUivGJ}9uwcDdb#C&7L2Ffbe($h7MH4p!Qw^mLw z^?4Spr!0aBfikVh?#b&ife;y?!PM3gM72>s7*_GzU9)&}^>RMkelO+mNnl>)tc7=Q z-Oo31^YZ-?MM9_{4@6*0p^z3{*9#G`z8zR^(AH4J6ufPXfk`pc!y-+>=ev} zs%jTJHuS%hP~!~H&mMm#Nwk~sAix!hGX#naRC#u?L+U2$bp@LPxV@+XFa-nKm03DG zK|pqmy^r-Am*HbwgXzjxRT-e>3qZla65Q0#4XS;lufrrAY4&+An%JsEuvvHh*t2WT zC5e}xLn-{$*Y0a$%IwayQ>IOy;ro_J;Ue>6buDV_xelAz#fnSsQ^X;cl%yt#KuGNG zOU^pP?JL(%Gx4_^?)egLzqyjD%1&&WySXZYLK9kAf}B9kC>AC!<|f+(3QB?Dh;d^ zm=VB&a&{IfSl#$3T3aWvYj-W3`z|7gr%(~>VeEuwm^J%(uD@vqYuEOu8_hdQ=_5Vm zm@<4u)sSj+k9|{tl<_37*MUd1i5{Kbl1#VAO&a5$Yyt%+VI`7rMnn4mq((x_^R#3`Bv**_B1hLqt z^;Qj3+MVGtwZ|I++^=@=-MaR@Lzv%8J%5U$Wf|2BMpDX4$H4?15RoT_L0{8J>IF`8 z;sMg@{u+3ss2L2K*-cU9)`eGwN?NI=@iz(>D zNovcbj2xSOs*!QB*yNb+?!FRJ*5Aw3&OFERi`x&ZS5?DJBGQ@BS}4ue{xYQ~EBn;~ zye1+YgU!*g38yNEROzuReV^06Ju4#nhtZn(d}XLde7WDwx%R!CuR?CH03EEgd_zRG zyVuPaGf>0SoIGdz`PMHVyc^$13oHS-R za%Q)J{-hi{dCMF-x%? 
zwKbOn$v0N{dr7n8imY~=m_11A!uyI!deT$PKt$f;4K zJk*xQLuI@H#v3-n42Ox2mf;;u$txV~T7yU0&pk`X_p4(lGTegz!5|VOd#NNn-bKnu zJGT`Nj1UD863{4Rhy)Q)5JABDs2tfdORzbe^T}5-sLfAF1l*ezp(Cr0R2i)^98n+qmagp9A;$S@RptxC8yp zcb&n;9PIo-)?i52a@S=j;|=y)2Aik5?t`u4(bCso&(C}kn_a3pmv&c0>#hG60X>s` zHm0=mq8VsntKt4;4aHPk5#zRGGFD5rG>X6YcaA6w_P^nEB&~%07w7Q**U`QH0}&Xo UnYOOR^8f$<07*qoM6N<$g31SVF#rGn literal 0 HcmV?d00001 diff --git a/doc/images/bnp-small.png b/doc/images/bnp-small.png new file mode 100644 index 0000000000000000000000000000000000000000..3dda0ee559cc59045b34669df5445ea2d2781ae6 GIT binary patch literal 12497 zcmV;?FfPxDP) zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3*tavVE$h5us}UP1!67sGNeMtBEazQ2=IdU{f; zyQ3W@MY5R8#Bhf*pcKyk{Pzw2#fN#ZU|w74r96CCx1I(s+VlDN{EUyi-`|fuzkiFr zUJuE>e@I-){LG(U%lY0vxb8n6sNehJ{PlXM>pM>S4)p!Qj{&o0PQ2LPMScg$>2p1P ze-AxxPN?n)yyPe;Ew@NU!GhR+1ogDM?eovp7GONfv4_)UwSCXzIUZ-66 zZn^$r{K8+M`$GWp{B!)C^}Qqi4991nAKNOw#p!!J`k&+XM(OV>is^ex`uPWa%=zQV zKg93ucJKM@ZdeI~t9&>0V~g)AZe(MY-^U7{l`rRYeLp*&?HSJ^Yn!dk@JxQ`F%j7t zE$%qtjpzNj;$n%GPOAJ&X{G#LYq_SL)|(X5XS}Hf^0RXzAB)UoE~mfN65n~d=3}zP z%sX;tiVUVK{_+`qzwm$k{AHkfRpv1&n0dajVqHn~#WIvRedjI`67M&r{>C?k&yVN# z4<$B}L4U*C*x-2PXNc+XXKm#-&-wAdt@eESCD#2NfDkdavzVC4z*A!OSmIMtsxcDB zMt%k>PubEW1JQ%0%_@`STu4q!nww{8eD6)3K9_+P65+H+B{eh<$ue^*^jo8ZM)2(B z?!~J&?>ZU?{VUxl#U%g<Y_4rRL01lotXT_zjF68bN@7Nf$09Hc}suFTu|!%KQb4T zy1!-aAM^HytnKl_ZGID^E+nQPWTySjA{U$cd$}v2!BB`_LqDU5zj-fO42 z=k$8-c>3D+zP!d1%`J{)^BB8k$g}4lrMdx#bFW-D%#us(FwFE+$J}@Jk#9NM!Q9Xi z!MyvXc~7(-6Tx$zsor@&TPN0g)_uB{`+()Kc1D9Aef3+~4RhYP=-XM!jBVXpx)BQ= zCR1X>eD)sH;(qKDU7zRu!>Bu-w5=p6EuFBMQ&j!t*ofrXmr2Z4KZ^UZx6B%0h-r&#Pwu*Pl=b0Vqtu3<0#U@hO z2ejWO>_P4B-gAl#blqjxw!=qRlc4Rpw~sYS>OC4o1qz?%fPtkiKcn^Q;@onc<=_tak=+>W2vDiDb%` z)yLe{S((xu)FzF2SvGM$)(EAQe4JL9<-m%FkH+|A zlDB8qcH*Smn4EZervG9rqtlf;#@{;N!#Ani; z!$^H}wwhxrp$ljT{A~R`WADS0Y2DT+^+Fk!Icuf{5k)nRQM<=Q%%XF9W(_0*`j0n| z+eohKNtXyN(!>@B)I+sWwCRp@crZ=#G^|bvc%(-&Z{=uPril1sm_T6O)BCV z*nnX4`$n`L7`V0Solae#3Cp9X@7!~F@)^?_P~8dCjIy=Cjn*?cBF1lA>o|Hk-$m2f 
zAytvwFYc+ah}8R0XFUoh*$K>g2_jgt&tb}ZqOgJnUr(oA+F*ZRd9JZjF)b3swb(rE zkjJ)q0i`n%s&=((rHwE_IzG(PSC)=#d4IEoLr#`wtF+BFPwR*NIp# zVl&nGb`xbCZFQ>T0=z9xu&6Hj2GMW@B>*|yS!n1ICY4oz3PV@swh_LoW}>GMCe1by zs7}bBR--7vm;r2*N~0IMR5yK8%Lue=BX$EAA!v};xNTenVs+&KemfS^?VbIuZ^u4$ zj{C-LnyQc*c>?RK)I4C9e35jdCmlDSIUbJ*G)G#_ADufJv)GH}tU=~iq;Z~7LX#n;IKW%g;7MaU zKyZQAb6Glqp+Wrn*x+eFqw&qAV^N&r4?3(uu?%wSqD*muEpE zXa!WktdKrz&uqThDYOVh6&0JX3baXZ$|g(@tazKXEnuICP^h++x1Q`+A8rHzBf17( zc|;%+^)0%YVN=GLup6Mht0afFsHqEGLP-s}a<+@uqq&)!{0=(oGKqEIJg&*TNWbi7 zi4vY6MYB|8vqMAAFgn7thaT5+c2|#_3NhEE*;k(|1k?cW5jKI2S^y1X%!qWhShJSX zBCydfY6GwXS~l&`d;kyJPG3`@2*Bjx?n@@%emI8;BD?5O5E*3yzMysKjqWfMp)w7feJ&?%BpzS&%_-&E10g*=oyLq<9OROcm9fm#Xi;n3-r(qkFd*l??a`eY{~7 z`-rgEDoYEKw;}ZIGIQNLs-@ zLbA7_*fXg@ZtYyo0eXiN8~Z;RgUMh@ktEt^vfo9k0WBGR)eDkDj~u#%My5TFzM-Ze zOZ-Yr!MYwrt*~vNb0q_xfV6f{yDXEom3E7$3MW?&5Dag#9_W)|?osv^j0rd~6X|7h zYDO|Pq77c9l4oT)$`YU;WVw(7>KLhv0vpnqLUS;D9G?M!HGuYkVk7=+3lj{4J55QY z!ig4Xzya(#a+fqwpG2C_r}J(T{WH={VM+C%8jLnV5sh=P;2;M@De_nD5{xiulZw!! 
zxhmqP0Rk8&N7&-FkkgZ1JGG_>^+8kY3h;<@p)qRWSNHa%pe6Iib*9lg*u?m&@3~f0ofQdsHB0W_yJ@-d5Smol3UtX zJdp3<77y6`6UDzvDUG5A9gfMEyhD%|05>fqtKwb@WLom;_72z}OMoyX{Z)Pn$%B1y z$0sEnGBO~%@*d)?E9VDJ@UpFj_*@xo#0XNwSkR`E39f}kBoNX9ib=Qu#*O?XHb=8J z!7$F|!IPxCWdn*1i3A3Cwg}k6F|C$|fzWCwvh|(}6G8z!Oh+A(7s(;mrUU&6Y14|% zJtUgz{3s#lLP7&Eik%!TYM_QdP2AQgY2&I%)qowF!3+dIb z5r7YrnTm(Iz>czIlE~tcmBHLT1Jj&w=mquAjz9uf+@t}$N*zU*V2iz=kg@138f`^s z>n5& zBZ3ELH3PH3nr0WpQq&G;S8NXgdn4-v^MH;Z4OhulCrY2o680;~;!U)NA<3VBCc0LH5HG;__M|n06lDm3 zQZ=mWC|D1qjsPq=qatX;2&;SOb#Y8l>!7|zjWQ@YR#4dRyP*tks zWpxH0G<~9_(XSnA{9s>jxJ6wrXfOEl8*EKElmk)@*BK7FK%RAi72q9Mfr2;cU$O_9 z1p~&T;|lyg!N9c^NzI8qareQLxb3`@{gQ{d~eA2=5r%K0DEOF`tL5E~R zTo1k@lmZbZ5_LiP#}SRRg#*|iAU+cUfYDu$iIN9uz(6QAI(5-mDwIwhH!h?Gyc!fv zwpv3Yv;?ps3guhUpxI>d07v$Fj%+n0NLv7glNDri7j!#%IXR#+sBvOK^?_X;Ntq{ujLm_z@vJ!QGTl=yIeTOkZP2D0 zuAr=((9&R_80yF_ZA_l=|2v?n;Ou2+X?eaE9cZTIhHpTbfm{t?aXzP5$W0U-nuOSF zDJ4rKJ~JD#CdyZLonyIZA#3K5(b_uid1y#OiTRPu(+Q=ars*Wrv}fe@ga5!@<{DxetGk)G3deX@1-S7=?GQnFEe6TT~C1(MW~QQQhTQ!)Wm ze#i-NM?52iX)r_(a6t3F^^kTnB+OWFS%WQC3v{o#o7?R@wYhKxQ3i6sxcErOP2ZB4Q?biYm@wUth($={^99jt#)7e>r zm=HJU477wlDNs0#08}8Qy_OnfxB0h`SDvS>iUWv~zG!d?H3yeqGD#3)O#%h^fdVa1 zhQtKPIZ+)nprb_&+J*3@SIO(>QAB%`r%e|nY%mb9sqP1aMx;Bi9*~7Dk4Sl;>j8l* z5<|zp{qIq{xe-owIaE-~rU0D8A`=v*f;`(`@i?F1;C49VRS&3pgg? 
zg>7G#e*&)ZAWB)95V$>vxhG?4RLHFUb9zp*BO&0Ag6&Eu-VVjbqtkwBFe~s~Zy_wy zQbz{D;KaBvNieJ(QnN~`0FlY!lK3R(tgz4C-s2s*I!dU-fBG~0dO=aAt8)bAbsCBP zOixv}V+NYQ97@28QaShwLW3oTfMCoZR1+T(^Z8OP`=EV6gQ;`9tdoZeb_mXRJ)~}N|cunuzzHbQQtV^>EPK}u1YbtL7{tZF&7>dWzB^}`P=d6qDzJO~8 zHAiP7;?w)~1&S3p{drFWeUEOv^x)z71ssmI1nFelBF-Lu#6wTLnL?`KyaD0b%R5J< z5OlEio$51@AfB11WhY)Tbe)o-R@mwUfl~xqf#4@3w&hjt&(-&gdjYxMrF|EMsAOq0 z0|-j-6onL%*6%jA(!0+!8!f#Wt`Y7}_999ZWdxLdODD zZR%|Rst623`dzX*$m)(Ln8e_-GxMVuBaP*Pgl+JX4x_PZn z2)&qiOQ}AK0004mX+uL$Nkc;*aB^>EX>4Tx0C=2zkvmAkP!xv$rixN34i*$~$WWau zh%X$q3Pq?8YK2xEOm6yuCJjl7i=*ILaPYBMb#QUk)xlK|1Ro$Su1<dj$A?7vov} z_x@ZxYR+OnKqQ`JhG`RT5KnK~2Iqa^2rJ4e@j3CRNevP|a$WKGjdQ_efoDd{bZVYB zLM#^ASZQNcG&SN$;;5?WlrLmFRyl8R*2-1Z?32GRoYPm9xlT2N1QxLb2_h6!Q9>Cu zVzld|SV+-%+{ZuY`XzEH!-L)(Wf~zbRY=8?23P`U?69N*N zNRbjs3MrFh%5A?tW->`gSnGFve(2|U?vr_P&zyV8`+Pp{Pq`xu%d`tUC0Mr2uPspy z7j9<+zsvb4%9)B%`n>TFrIlr5C8y(3HU8JIvuuk4`6ro}eLbUlUitg3oUdTBZ4hRF z7IBGm%WQ>&_*Djy5b4IYtlxL!d zSJ0?&W3sb*A%x(&@4n-U^(FIt3kmido} zh#)<^5$TQ6Nk~jU4UHo_SHocW(&ZEt6_J^hN#`z|>-2r{#7S1JTv>Z;2%@8-Y1E_< zsi~>N#m6Iu#%2TS)~@CFvEv9^(6vWbnm2FG@#Du?zji&AU3-j@B$1MwLbDdlNJ~pY zRn>auc;?I*mMvL^jZN=MdXt)(T5r44`T2ala4}0hUrIq?0kV)tN=fFT))z5y)JW3P z(?VucQ7|o&O27B#!P*Y0H{I|uZ_jFCoedb$vKzxM?1$IzbD%hn4M%f1<1NFfYG6aD zuN)Bgbw5(bj209W@W{gx@cVpdnnr4BLoCbY)QMC0{XU|i-Lz@fh8JFZfvl`d0Cwf> z)3osW z{b(+g#H2(94jRCN4?jd=VglIk*@A_<_s)CR!e-W-SwXNJJI*5yJxU;;<8*2?Y?y`~ z@N@FyDGb9Tx`CU{T{`pPD=(5(50K`6ynu<1PQ>r?^U#C`nJ{reoit zQ4GUC77C_qV%ZkUK3`6kE}cnFPeT%bLs8L918>0l^DaH~*YKem0Zw{O^WM&7e4T%Y zD9uH?Rxn151&#+RnRmIAp~vPw`K9N#jIGfoJ~0! 
z_}epoqq4G+FV}s^Yg4C%C#-3X(1{i30i6jCJ;vcfhlBkoDw?VyB!N?NhP8zj+F&6G z`d@Y_IlH#9E%z&yu3bi-e!V$%G>?D$eG*frOvNM6vNP>ilWeU{dYg$)TvW-%;w}t zlR0$Y5KWsmVabXmY}&qo)nBY)`kSvaV9;eKYEZLFAkY*ApW&xc|LIpnz)yL=hhdpC zjEtjilMamScqI>azkxp_HVF?dQgssRbhBX3dcMlri`Veu4fxS511SUoS(Z^18QZpT zxm-j?M-v?#O=@Zy{rX>uZQ96+Oy2Rl8emhy=0TDrPUPh=ZR#|1T}M%5R87M&EmXDI zq}BW!85K!FLINobQ|Qp4J>%~gPf2kp?gmjT{A>~afWHnPBv}e;Jcs5WF(HA(q(oY^ zYRP@~-%CE;};t#I#UEdVL<`&=FUUYR4mg%R%E75eifh3R~u}K8Zs6(QIU}( zCMJ@coJ8ApZJ7A@MB?J(h>neB110r|Zyxt=qP-Z_hqVBY4oh+5ZineY&s~D0N-u=+R@Ss){Vj zq^32bS+i!eZhaBg4ZDsR@64c6=T0FS1jvd)h3=;!;Qh&$g+-f`3mDO60E4o6qB>n{ zIdy=UxhptTQ5-x4rjCt`Wm_CCIZcJ`3w~esqMJsT8_AN49O5}uT%xS3jN`|T^Um9E zBPkN9s&e}sw^utDq4sg{aol;=ok)_zE0bQ~)A^sGC_znWS$3GmRXwL(HelNn78aqR zfE@&)B@BrP@2Hp%_X^$|ycl8rD*>9DLT&rOPPDFQi+q9$Y*8S^&0vmBar1 z`)h-(ibL7p(6qD6cL^1qa%^mzPG@-CG))6+1OizJSqWJVjol7w%hJ+PjvhVA^nXn! zGR93*R20__9~K^C6|a2l)oD1K4#xg*EE$;@B*Z5$d)92q%gbvVa*YczGB6B-Lx&F2 zvuif*y!{UO`T3Zpg+tSDI-DVIRDdMQcmlp36=AbOjfEYYUb$|t{>ay?Ih=!GnSj8u zEiTXKOj1MxHXPr>hq)`cyVI3q#HMnlvW$v=4-wiR(=w}xxDa7er>ZJtWo3-H{Z97o z+Cxc6390F6Oqw!@p541wo3s#!Dl00vY2;1p*}aF=D_1l5pHt}A>0+Et4a>H|pbP@E zdU1=2irBGz2Of`yoSf~vJM&$l9f+G>DJp^=Yp&W!jYRJgVkz zW234XmEKD7zsqM=?k=XjIt{}#DJd%Bs-ahgwX0GCEQ+jD{|)VxrfB5n=QCpXjT}FA zoRX4KF1)ZMlcxNWwr$(g=3$$+F#ogpwE06@?i_mu?r1m9?|6>AyZ5qw{dxuu9vlYS zU!M3YhYlWM&+dKr@bcIrk1=`D6fV8&Qbyl4npUk^Ap}SQ2}zQuF#MF)b9v6XN_8Ir zNiaC08)@z&$_yVBhL7?<1!ue-%KVjK(|Tb-Qxc<-cw*IC{B7+!>^gmroD=&AmdeP^jzy%OG!{t zTtbf?J!spm9r;D4`0nr#Ui$k>?AX4e#u}Q;ojZ>xHJSl~29Tbf&fvj=k&Ga$3w{PCO8UOj@KTLaLDlJ=OakAhfdB^gY@!AZo9&$DBz55=PWd$`GLQt;z>XKAh zA=VLvTXW%39N0kf1}R*3QC}=9-rctbf53~3sIwW8kT~KgWW|>sqS!W)2x_`RMJFis zdMNOg;Wzaf(Pc-yAqLa5NKbFX6Hh*YX_}0`=WfQ{c^4bL*uWkEG-z-kSLq5fLsTA|i;3k0&K1IjnX3em^tceuu=kM4Gi|&euD3psFgZTDD~K z#?8dUx>>YnF|Av-u7hUVwsA*CbMK$-;Xm)qLz04do{Kwos%2tL-3X7%u#CFlCBq;y zE0d?5c@o1gxP8TF?z(dvE0?ci>*j5&ShI{qjT^CN&tB%u2|g!2KAzR9R^#*gY16eW z`}XW*+gIDzzI_Khd-e$TotT)&unVA@&g|@K502q4pP?nBd#`2}hDfg5kBvA`sH9yD4#?iA^&mWi5wJh|| 
zb0sM_wwRdcdfTpBx1L#ZW|5kd#?~!cxn;wxAcDM6k|a=I`m1j+YSbvw(o(}QMj;@e z2e2)R)~#Ff&=U{x`qY21dDBxCzyZ^rMA+;^L8la)M8Kw~?^rmUlq8)wv^`%SK zF05O#j&Bct%l`fQNli=T!}mX++*3|;tee-SzQ#0*5GF*Je!M&LUAlMgh9pVBa7@S% zal760zpOvKFX>Hud>n7hc$)?hv79_{DyVpE0Uq5Sx_o$aKV<Iq_FW3fZp zRX6k+TC;lbeLnw>&YoD7MZg#IQhk-)I^JAO?flW~ImAZAaY^6aEL*vhHD9h_?U!p< zv2HoTZWu)x zIItDlGBM6FyH%V`H;iD)1Hr+A2YGeMs|X<|EGT5%+I8&OxtpT0LMBd_Nbf$q0SNem zn;!55YK0M+G;PWgPd&jsE+-}-g3{s= zEVKT3@7(z{pD$WUarqgpyy_}4LxJcj8GF~Ay#LmF#Ky(4VEzI|+&F?jp!!ghS5|}% zLVSEY4?Oq)w~QD~WOM{CJo`ND+qWYjA%V(@${?Mvs@KEk^9K)sfL%LxF>UHJ48!E~ z>C-G*vXoPKCn@q2F=g@;+O}=W6k|`-EWz+gic*-ic_x|x*kkKNT8UQXMU=O)rsU@I%E*Xxyqgs9<7U66SJ zM(8BY|9C+-gpi(|PX7V@>DH|q-FtKoa~4^Z>DjFp4yO}QBTONL;HsfllAD{$x;5*F zh=>R}3X+5nf?hqc(KL;?`1o^brgNs}grG~W4#={MEGuCT{cP>`*@A^+walb-n~TWK zu4}3@T4XSM)D7(0vlrX6+4=QO+-^7ByY$56a*~vsRBLRPUvW8O#*HR-XHYABGVc>^ zyW@7Equq4u)DZ+!DQVlbY1yh}I4C{$qj?BG(NtQuZq4;KTu+Z4J?PT43+3hIY}>Mx z4i~qlO`BWlO-&8CY6z=Wug2%|^6j?=C@m{x$j~9UB3$g>Dyun(?necQa(pQ~15* zzi0xtV{&u;-2Ga5x9Awoe4WGb`@U@1kGgOFz`p-mpTES|>g((K`nG@I+WeHY{3#L$ z1Slyk!Llu)V`7LnE5)ZEK#}AixvariY)N8Z^N##^(QE84JdOnHE;z=5Z#Gk2;RzPP z{eq&(GE$8o>rYt>@e}U<+(j3^U_0>(%!mIGu<8>{(@08A`fnvh0!30n9$YnmoU+R7 z9V;lRC`A$i8&G_HW^P@AWSacaWsX9KdgE-r%Xw7KCIqOm3TJ6HLl_kMd`LAEu>eXv z#s4Q~MHVum9 zk$B5|%$oZFSy@@^*}aG7pL?EJ@6TrUuHAe%>jP36rl2Y+aqb4FPMK?mUBknVJk06* z(=3?3fa%kwvwX=iUU=yRhFyOhLWpoZUPQDDhqF4D>T97H@qxgVeM%Y}>-` z4+d+sa}+L~cVM$^z_!4$t0T?AqIp~jz0)t^?c9}ELdGr0|EKRK1PnbKpF4g$k0+ma zk~766bnV`iH{X6USRe>wWoDiu(r+{GFq z^Zc{VamgjUNlH$tRYg)0Y>^QWV1w>6>C&~!IjddIcaf?}er(Iewrot(BHz$4EQ?Xy zF6Y6n*Yb~nV`$c(A*NZ~?w7v&dH^#NOIo^Q3CE5cqp+xeVZ*Nr1MR=b!Gni*>$P{d zxJyUIjlT=qwt0Qpzc_t5{~R7#bp?oS7#uoy5E&KCp@p*&!OAc>oWu_EOIe~{zrK`I zmT+M2A+~JU66RlF&rH>70MCgHI9(1rWgbdPN+~HR!Rz(?wj6{|QVn5Q*rrK~l*T;T ze-zysw<6jRL6@c%vid*{N6wr?vg{wk*Mh}*KTZn<^gvj%ZQs5FNs_QIY1g*h4@w)& znKg$q-U6<^<{AbL8psFleZb!R`&hngIX8{Esh;?njL++3#=qZW&8ju5S+$17jT>?E 
ztvA!5V~6j3W@2Ju5MZ*z(IZDWd%8-E!b~mn*%j^J`fG2%wrzZsUjF*@Ul}!O}R%jpVNaRdT^@BQwnQ>Xa!(}iR-&7@oRuDIP%T+*jEh53cdn>UZ4S6@X` zRMc5uvvEbZNKH*;^~%*$R(N^trDqv2V#GOL{y5w74Dw40r_;&V&SZTr_Slxq@Efnk z>2y$5>Y;I?M!z*P3bth;6a^s(e10$QZdgR$Hy>m3v3=OK#jsA7GWqIzNNNxdl0ZlT zTR=eA6xljmGTQLMm1B8!$QU9VE{e-b5TSNI^zxfl*z?WqosDH!M5)2D=kJq+pDp6> zfkOmzKM&mhAR~s~$c8UBAS(*%R<7sEFTXrzfU{5EOK8`jEviH1^My-7A&MG}&k>P0 zeE4w4iYlS-gxZ|wrtr#AP@*vLl_1}jvQgZ`~_H+MSgw( zZ@l(80dIgmKlA{%-F7Rt+JCbulZB z?qWmUw|rB40?ncQq=A^&;Z1hC-8}cgvpn+9qio)=nW0x)L-(HD`NKsQa{O2xE0!#0 z*of=6YUovL+^~TYc_(R+*@9bd8%?y^U9DYAll3dVV9VyO*u85vojP}-yu6$$>QH5( zTefV;$XiD7%;V4U`n1>S-lID`dUQwEbt+0KaKr>(T6y`!mpO9eD63blqNKEx-hF%X z%(G92J=`i_+P00?Tdg63HTyxc>emDORz(3Dzh$6kDw1v0;vOy=*6iOJtWT;1B0=NC zMtF1|!)HCi;gZu>X3$i~KM`z(B!$5Skd>9i%=c%qV%c)$f3kobTeq|As~l1qCUMR1 zYZx(N1eRs9Yu9f2UfP#Ig9p|E8bIH^eHnJsb?5<|+}t48qTFu!UD^*>mT`o}Dum#U zJI8SN;2|n1D_OdD2^~9jKv7h#96lt>&vJ5h;Bq;+anwy@XZNCg`}RaeMOHfknuB52 zT#u@%VFv=BVQNDLTs{!XvT$gQ-z?ZfQDvz;cHty;7oMcNMh(?}|5nh=_nQU^B2=*V z^F&i+Ug&iL*R}3jFBuO60$7%f8hWR+D$Ee98Nd@wvxJs#6E!zDszKl8U`Mr|9P89-K)T2&W6L>#L2iw(pU zguT(TMOD>vdWhdw?0EKlNs`X_a|-qK^<#=Dg%I_QGg$8?^II1Mz-Q{Hl7%cQz^_qx zVWU}ko!_JQUCu|a0e;)YVe7xzKrGwBvVY%@^P006rfHunFRDe(e=Uh|M$y0-`TMS% br}F;*>3ZdX{yj7z+vX> zy}Q@?R);CdOCTfQA%KB_AxlY${s035M+LqYgo6Q|Kgi^$z`)SAJXF-3ei*ot*g4pm zT3DNqIJw)IkeIkxn1X@1tyQI2IDmMg13yP-4&gw3_gqk48oK)kd<9}A%m0->dRVvr z0wpAdNA5syZ4gc#y*&zE-0FE9sJtY|oH{S!J&lvcu({6-9G%{B2EToLcx_%*JiXqJ zTs#I3Qta1we^edaa>`ey=%F^gbzSu6`L4Z0FZ4hTM_DN_ zyo7A{Sa9hUw0PebHaipin2FLq>vv1ecB0dC^DUW?9sXv$^0&t4)^|Q>)T&|n((t_I z+OzxgX&o&6)BCMICTX?C&I{(HuIfc1vwPX-dB6QVQ{cThKefB1P|vn=7jIWW{u86e z)3xW(%Z1a?8}8@3&A$|DD|=V8V$9_Uk$M&a=#xkU|?Kw79T{}dI|XU=jcgS z#`WC9*AePazwc|z2BI9&bi-pU@L&^6fH1Bpg<$Up}pdW37)h+&09UxY($fA-g` z{>GS_`)lX{jYY$F;?%dX6Dh}A*@(-ovcegY;~-J2iX}SINT znwlpji5gXhnwC{n;U%UG$yydoO-t+M4O%l@cWYXg-gn&dOy-q*1<@QYxJL4nS?(-L ze?P{c6>PG*-0lxILuR18)O1b${87_0n#yYaMZL4?+_L$&aytl_C@*lZxF9dMKZ&?h zz3DYk)w1a`k|o&iF!k?J~CD-j$l 
zE3skLXj&>wdAeWnAvJ06TSRv!*HDi*aP5((Fbqg(Y4cIA&> z#a?DQaiyNP6v>G^!mZ7o>5va=RmNn^X6`m=3w%=>@gd5+b9;3>l}ixG3QC#u9BATmwWx?~eTa>9 zm%_i+WZs5Kp4X8nsdkq?Rkh5BxbGYIj!u(bELI&D#w_RvaBt>y*$ zjM|tIL@@u_1g(f{_!YuQ(OSP}i=BlqCu^7N!myxm9$M$^jkV(NooXRgJq7U(z1e*jSUU!wvjec^GZWEEOX264H zoL-TzwPsN=#hmc|=#^zj@dz8;aI4-zU4-%f;UTM7Z?c|bXSKp>$!U5{HbIEJa!(`c z`}-{GMDgR0sa0$=Cy5ZjI$7qg8V3o<-x!#fQL`L@&SJN*OnP=C^lqO5?*Z!yVRsnC z8%H!!%5kCAJV~@uB}t9EIKfnSx2`sGDu1k8MPZNcDINy4Lg=DL0rQs8vT`{7EvBsk zrXo$?kWztfmJ)W*r;Cp|<$NIsEIh5(s>8et`J&5y3t7LtjU+f^1J|%w5ZZWp^tO^wVcH0dH=lqnX&Q>R$MK5*V<9qPi=(R_zxhg)K zJ5`ss_YniX)?kGU5xepnU{M9(wdt0w*J{n!td0mkr=i{=8Nizsi^CxzI`E@2p}vOc z+MO{%m}4>E43oqj(@Xo4V&+?Wg2^Lb@!>|4OqFxHi8`i~rZ4RXWRtFvru}YLf2*>B zotKOvt4eqx*|usZH0(oW=vu*TwN8YBU({-PmZ>HDNUZUrNrQn5J^ji?5I5_bdDY9* zX9XW(w1TvuYEaB|GiHA{T`hxO!3XzE6P}AnxkFg!vR@`%rqx@$#MH(LGOT-f>}ZxG z+S;Hy)n1)@d~>0X+Yq!{-04URzjGeNtjAPm^zVDYXumGC70;PZgp50aM=y@{jl!Q{ zvq=_3Ia{@#k2s})`gTkkuwlWF9639Bf%b@f+n=%meDkPLGVZ}NCI=y9yk>2<~Y zK!>A<&R(#5hx(uF0%zmM|GFB#CPwqe&lH(IMtBk3L81mhr>b~i@?L(|$w@5C2XDu0 znWluP#5}SdMU}jwDh@+-B15pFl#a6WG<71e9K@M9WR1=%GCeyU%ryF@R?6j`u3kAj zWEDNa=mV?MF@L~s5>g~1MnK*$k{c;b8Ygn*&Y}dK{j1Z*QCd7xB8Zw&{bmH8^W@b^`m z2a0{?`Ap+}`*CG2grpzIZsPM{uG_zL3wLO#&ztoQNTnwc{FLm+`xANfSzEuXf6pRh ziq0E&C*8t`Ja5Xm8qHAL9ZX+hL5j2~KCkso;^n5J&dkGDjpExw|2GoPdqrmY2jN3j z0s*z4{@VT+F89n?zgJU9JRHunM{7{%3mW=MIL4h6ax+=K@FX~gv~$&Fq)`3`riF?k zVvGL}+@lCujf=JuQT(QM|0=bBia(6twKL%zWec629Y>KsDaJ-RdDR*DxU_PaCl|BW z_p8K8&>g?~$&aqz_KJ9<$q>V*0Jp9P*WuFgiY-# z1`O6@mDF?nJR(hrB)!RWWQpj=Fo|3h6y+r?EZKL_%IrxpKl=J5#bT7g6X!(JsiKJS zP(F59MR{EaAHDOO7bi#IosS9a!^1HH0vJp9m@1%QN{|P9TLo9ZRk%%g!Lv^KgRV!;QVNBFPNuJ=aHqNcco}?A+1!Keq~`xsBoR7GE9>fJ z6!-Y2w2dNe(5QZ%Gd$J6*47mY3WD7_krM^OS7tl*iV!SCd=}w(2!%l#3OmZzv5x(n z;7QjbUnbfq9gfqVhDV46@n(lDb36YlO|Hj|HIp>kC|z$s;g4!Azb|Lup_mptAEPz_ zVlU>`c*Za^h9$-(1a{!=eCh+MJl&AdY&53+`gK}boeFr1+58dx8WGkxgBe%ugFJT; z9*@#2v-bh3y~##eV4TxR^MPbIv#GTNR0Moi?xtut;@ov=G}@GDn#t&eli{Ui&Ow$j zr51@8MZ99iBzRYFoO5`xUtQ9@pz_XF(Y%9;1ef6h@CcMhJ7Y;ETa@sqa#^Bf+AVUI 
zCFFya%8y>4tXATzs3(LcnK(%hx066hu!?(v=AQbMW1>x|^Pio6>_XWw}J6K1G!frOv)3+=l1-ZDIV z-~{6i+)+#0*s8gXMK9s9Jx%4Ho~1R?};44zKoNE+8Q%Qg0IIpK~_&50;7VIIT} zJ-IpDgCSyv=q{YFuW-?U$K|x*MWuJe*$@@NN@`h{U>V{h1~e6@!c+g$u0=XEt|+tY zSZR=9=8?v9+c^6DoIQvI5m(TmO&!#Y_q-|tsrLEWMiv9hYSR8*l*d7hU}6SL#B|CH zIvF;^?quuJ;e9WH@(g5&ySfc*bLh9;``tIRgTJnr`#9!80>(#VO^A+%g#J#?2S;0m zY0cx?mZ&;h{Uj0$P3|DsGRXL%pc37TSpOY;{C=Ts{Y!VhlRoMjrE{boHW8NcSC-oT zfUYX0aTvZ85aD@MlRAXkf4hothq$a#q6nL zc=@vE$Kkc3Nc$m4wTYWCKvZ9INJIk3>X5ZF9|0lN;_&k)vY7i1g)Z|~&Ei*`BxdJ} z``AB?g|vh}iv%(vCvYCZ+XB-_ph^>bP-iCBcnP5Mf2BSB#sQ~Q;N_ooymS9)bruCp zehQ^NGRRXY_Mt(K4936^f7uQ5je|zF{qyIXQ)c9~I1=KZ_;0l)1FKqtEXro;ypo^B z!n`(ILWE`C(BO^D!lWV1+05MVtLf>J4_|r27gVBVg{Bwepw9ZG%rOpAy>P|CV2YsX zr`?3OSh$(*n{S)Yw8hxUH;|OaG)jSsNVuJ{r zdCqL8AHP$Tlb_2-Xy>~(k=dG4v9P@#bA7r#K1AbXF4+eQN|6$ADd>0sE$d2Hzwrt_ zAh-!`sNRivuW&^}?8@BR&XX9w$C*d&2&rL@PaI+uJH?N?z$i-;v}@ZahS%QTOKi|G z-KcsbuE^6q5{MWvfPHoeiu2~z2{CWHulMJoQH*MxwJ{i#Tv##?--1pv%TdIRggVwo z%CB+kR9j5gQ^-yr?ZmbJTKES2Cw50l%*l|g=uO{10%SgCC3>7sWLGhR9u`FgQ_v_u z95!nxo+YaLqUA00+OCe*lcm=pmNI+ph*HH$ex{>|1irenT)QFk`ZuJwXHU>cC6Ghh z9w&g1@RyEv|MnMCojl&zmqI)zTyPw67;_j|Z26M$XUNQonL>C@V+E39A;k7zYnw=- z72_OCh+nfzSp@Zg%HBW2&5h(arBQSE;)ucwr=?)~|EW&&XB)D~G4Xxs&V96kS%A-= zWq*RBdhIt+aYqUQ=qJcRL_|?aMCAWT55RXaz2kT#2Nbb_Ka0arzoEFG*^esnhC(Z$ zOH_SF{}!m3rwW^p(VZ(L62G@c-HJotMW;rno$2gshYi&S8yZoOkROrgX7j@weLra6 z{ps@EcbfkT+CP(G4Z?lD;&sz+=wDb@RqIZ%`Qensn7^@3q$U)Cyq1^wA-sAbuIWBK zeX9l7&6r6@FlJZ=8pAk4OE`Z8+zQyIVw)sAWx7ok!h{jGRlB_nx9!yTvJRN?uqzNy z;voL2Y-FZQ$GgU|7yaWj{)L68BD00d>foEz!X z_W9peZdXYH@Cv-0q=q9H7z6o#4{*04K^NdfSSKktG1z^0EO<(8LKCHAFfbA@DN$h+ zx3%*wGbiokq@&l4GK}(AM|mjtLjutsWZ2YU@m2a0_0|(Mu8*w)vI+{m$Js5WTOL+j z?3WW-wx-Uh*=ZLpAzwpmBsBV@lZObvW5K}Dz`!xM&2k+!?*umbp%I|{6zeHaA2e%m zCR87@JukRF8wPwY=~Ig>(*5pxK3_vg;NYxBgRMtHuveMnf)6QM*=NN_aGE*Pyc|j2 zV16+=f|JO3RbxG**pM!#)YU<77>fw-x#2oR-`{RCcJ|MnaphCTr|%6MNeN>NZKuMr z%O>b^4SF#N*?Ta!E11Ua%e4qOi2V0gTy9{4PDyQ{Icc+e7Bm9V4!5IsFFpkv$zf6? 
z1c%d##J|nqre-;}K7e;40gVa{E`_EDO6#vA2(PcVA04;NROmf*!i*MA%eb|D;{ET6 zC|79Ez;D@)vbbA(ZN;YAvyC0AC&?$-b?$(N4zO0I#jhtN4ml)+8;uW?v5(E(8_45& zl8Y3>_5Tf$9Q7X*uQUm6Q^pl;_#x*yKc(?_-=ht-`8L&C{DHsmynM{7J~9^4~8} zJyE6NV@=$3merS^(9lG|m9WmA)n|CMUj%Tr^5`ZWc$twPgmX(=3%>|%gX}?&>?&TeR^`t=@9(#sS67WW8>rf7fwW#un15>-Y1&}RB*(!>mi|L zC-oFS-#OX7Ya+>U+R-vJA0_UpZj=TOTZ4Mt?1IB)=|0p$OK~C~$6#mqkO((t(V?09 ztl^;`Hq2Z5Ha0fA#~kYE65{tLSZNc^&gkYB7VJAVX%%?F{R1ftT+d?$&bm@1IwitiG(SYCJrW$HBSsT+D87_q_OHLm6xI??1$5 z$H|+DQ+BHm>(+b;8Myw91KdI+WMs#VrX+K=_I5ofEc-U?Coch(hI&aEg@L2{J%^4< z{Fi0tth_!_O-)TjRcs}+v8f4@k!eEU7HqgizI(qnHZ|n|bbLEb6uxC$&QkN+*R-#N z#l?HK&vtE_p|q1B#ivfLu8u9I40Y<0Fu9F)Y%b?pVjki!sgw;34UKFqOuVcexl}c6 z!CnHy8X6kLrY7;}$>QN7x4*X}xCn|1x&nTGfXm8e%2WNxC@-t5)aLXvPal37ggS3o zZ~fyGapAc+IXSI)xeOH|(sz0rpOz#hEDWBcn#6b2i_XHzx{m-$f{?sQHU?2_k!$&{ zM())vf<=j4c0rQLyP@E0vck~Qg?ru6VYWf$dgQw-%n|O50ux@U={G(JQDtUXqgD6@ z(e>KCF`hXJQYGw+^vpfq_nAw(j(eAWDF!fqYA}E3Z^dC~g7$nwsA6JbXM(jcA`&Rd zS{=|ijLe-?Ry-*TI(s=E*Eyl$IGozz;yl|g@7w!jB??r(!WF3ppBs*z`+DAD6{(?w zOiZ9ZuRmX6j7?0|mQkS~I&Y^tI$y>stZZlELQpR%&DmG0%z@qr>rS4G7?w0}%CpRccU*fKV0S&dDOs{i?Vi3;`f zJmWk4*z-CgN`?+4tbcom_i^_5#OiXf)^>QvB9~BHto63s^G17kc(_cELnp_c zk(WM4o^0$K6a+&4T}zeh#)f1F$%kos_dirrMzgy*xLw}3w3@Al8&J@gkpi}|bw3_# zKceyPIhmO4J9ILAUf8x@x@_Z;lAO1#0>6{u{jUKTB)%xUZ7?@QC^|CZ}%FU9S!%}h!`K}Ma+l5TkL z393(M_QT8CI`Zkd9Og4~GY|d!Tlk)qPYoQ)Jtv-D*xn3i zAY6oeE1s#Abne#=?(uKqT7G& zIT@&& z?ev+iQQ{WX*7+Bjghob24<7ukhkO#z1iO79MB2))p`r!a?M`7m9}_*rO$W_YQtcZI zS`kamS>|lgAPnIJHeiXAl@}f~wG5lp(?dw5j)Q+2e0cPc8yMl^iQF2=h8*D{%U%mz zyGZ|Nu%D3E)`eZYV3SsM{0c9dhYzD0f-e@IhBT$bxiW0M5`!+GJis)I3@L_&mU4!z zgpsi57P8QHQX+r*Y$g~paU%MM4QLTvU0vGkZfUeBHB7Ah0o#`)MQzX`oYXY*yMG6$ zFssRF5#m#y$dEUAhV;A-R1pvmN~^2u6?l~r%zwZOVS9X>gs_ld{N%?&?; z86DW>-Vj$kT@A%*y(KtvAd}*LQ?{U@s#*>25S|bEm7X6@=JSc_6B9rG1L*jI#JM#s zDZ?32Z0wtchF<`P9<=U4$r9 zGklLo3WD!g`5?`VS>oJ|j*v&+XMua?`=_IKC@H91DD$H2n8-+=BIp)9t0C5@BPn6gFfbhcC8t_dQI1+T z0pY^2u(0%Zd2Rx`n^}M(8YDyKj58jALIAIfHj8wwj2Twz#f%|87N+IGL<|qqd`d8R+A4@&5SKm@zeTnCeo@b7qHQE$- 
zVfkfvKvTcESih|Ge&vv!kiL8s07;4KvY-f~jZRM=_VxuV+H67;7YPTFqZ4_bq5Hh- z=iCJNK9edZDB}kA{e1%2zQ56<5b+6#iwAeDA}tsjyCt43S2MG-H_5KgS^Vmf^VZ>~ zpcvv>o)t*uVfr$h5|`GOHxN=*Uf#%_X~)$+5F^l!U(+>>`f(5<8>yg=3sXQ&-uZaJ zxYAAF;aQtm9c9U{*|gMozc#CS^R-tm%+G(ntvUobvJ)UU0&6euIC&%~CnJV6i0$iJ zbIQ=7rTL5gyOJ@H&mGA3s^&8`5{cGifC-6gj+%90VdK1A=C%rFczDFFgP&GZ3xux2 zq~v_QoQ_5X77j|tZymFHrM@#!PD~eZk>(Kyfvp*cQkE8VQ>qmj95ki_8pZ|CFgD52 z5GkJckOq30kOxij%qL)yFf&X2sdERv?E8#Qc-{v$(Y6@Ua2CE`bp-l$w(< zFzX!3NJe^5tKFT8-peLCEI#6Pdt_U0q0Ek&7EZFW8{Sog^b!$ z1Ufs&YS)vG2!&W+Qs|RJCrbyAWdRnzCCzOw^}P)J+2uZBdcsCwh*4y1t;o}H+wi}J?RqAxZ8yng`oNUS_U*EO4 ze1G@#^?^PE4RAL)JY#}i@E5JOGL2Cf&58E+_a~>PDJt;+xd^Bs-J0c(m809ZtRyiL z6AFoFf!@hQw$(Q72O#UF&!;xb$6RRAM&SKG~{%USSCZBqGU>khnO{&x6eR1 z7Snye*U)9tf)Xa-Wc;Fp@s*HBY=I5fgVzJZZ?QKw4&~()v4f(R;|L4mVN`l=SyJ_h zox49*Kf?!v?d{_Zw$=9}QG)zM$4}@gRI-XC11Y zN2*lHgal>nPP@3cI3?Ky5XDfy!hxZW%@#`6mbrQEff}NFo4ZuT;!-vR&2FR4ES{Rf ztaQAnni_|@dz)l_U>-h4?9}Vg*|i2`c;e2s5SS7UI>lVJj8uI5RtPnMWPK)EO>zvq zBomsXx=2}BSz>bXoE|t`jrpuH|fHG45IN*Pn5tl9R;i1{z07|6~E<^wGNAK(74>Zxt z`tS16Fk1B!)G{+v3t8cGV^E7s=L-Q(1*C6taHRD9GA?$N#`q>>ZIb?>VPQ77k}`e0 zfvGn9BMuCs5ZBi3)|)U0r0&cNbWiGZ!NnGUUB8`p^cgFPNIApoKX9;^(sS+S*#{oLzuA zTQErKjJGytd>mox?O97UQY$Vi>{`m_dE$}V?g}hTLxVP2O=FhN+NlvGBWc(g7iM8$ zAu>9WQgp>wqwh$jX@qJ}4$C0z(7>K85bE6ZIgc_l|GZU2(wHVD4i!)BM@oobvtf{bXIP(LK7RlCLwnBm(vwVBR85T~{k7 zWTa@@O+im^0!k9pxsnKJ(^wlQsNT1jPO_-zkwNekN`~#h2AYEe18jEEED|!VJIZt_ zC|1nDZwYtPQFY`cnc50GlxTB?fq$r#_^2M2R|J`h;S zn%KBRw`_T3Y&X0lhaivmqhA^Kl#kyAe$_kT+8o8CM3Z81_T6j-X!E`qRc{*B~ZuerOp(mTN2nO z@u0!k0k;>>#7uGti;XJs0$%aI)?4ORSBLKI*xHLRm5nw!?Gu0=P2_P*iY2BSXJm-& zj$vILM5pzEFu=CA>!QdW6}M}k>wE9a!*f9w79INc%ilV8M>+qcHhoKm<;vvgE|3&F z%3_d+R6Gc9=xt?bFE1XH!y_kWXZb~KL1kk!qDmkt&X5p!zQGAS-?aSve(c|KOh_*t zudNQx2zBaMQt{L^bVql}B`4ixRbz7oGL9M9>3jV_hy;9IvGeouCWqCcs?OLkd@oEQ z0V2R@(drf&Fm`Z=Z`t}X032tlp0-4-Y>|0X3TC&x2r%!@;hU|!sGCEzk8 zqJJHf9buI9%1l_)iT`P#^IJh`iZ5t z=Hf3g{z!rV(WjmYSfSs!X_XJ~sY~$f+6OrCQxHmMa6#WQ-9PDkUj;v+{kCljI>Myn 
zZ5U<*KI#4wckCG)yLLQkDkUj;EtQp#U%KBYUS3|>)}6lqn|i@MkT3To!}V6 z&I(mdD});;yFA!{iH%ECLpH+7Ur$Oo?5&?K|x{iHYDWa0gRnjV*Y(-wq4JZ1>{D} z?zmy1$CsyOZsvhjKWmh`%F+u)h<5TYsY*9BBW6w;*csN3Na4uRE%b2@{{^3gpoLH; zT-@x%VikXP47^wOf3O$4)xZ!#==@4Rv-z7uI`iaLzQGfUjxyw=1Oyyhquhgip&KpH z>XWSCslg@}D7*>}O}zSW?C1?c)MU z%NSxxTsOVw0*}un&(oi8DEs@E0BT^!?CD+lr8PD>DhU@>t;?dMZklcmbjGmg2j0;9 z(DL%9cOL*bgo@`Uzn80@4Nqlo-p>Ej+k3h;vJz)XwhaWqZg$yxywrT)rK|^4;|Ge6 znwpvIKfICcHz+mBkFnlt~|4jD8WE+Jh;KW*zAfhzAZBWM*v{G`rw^Pg;a()=#xmwI3hkjKU0Ux zN?FiEn}F5ZQD^ddMhlVnk-%YLW82@YD6z1x+P7=}2U4`6TF%V>9FJ0enbtG@{QVAq z9|{X|lLM2tix!`h@p9>{*CEfw&F-neK~Q@6s#of2nug8d_Kv(KEZ_+Cfx5x%{UU{c zfJi47!DR>MU*uleeqC3J2WX{)_~|WwJWI#%<21s9u8w=r6u%J<`F2?aO5U6DR3uop zo+G(-vtT%J5gFVfh=JP#*`T91{uAICzT6;VJ&x&{a<4Goh*8$W^tt`*Uu6Y{t6Aa0 z!=sb#<*BM-tvjbjv@6*Es#w)jm3EIulfJ>BrVTfa`sZOMQnEV^g@njI^fv; z>#L&u=))t!?(ct=SA2JRL#AEfo*+58aQo6lQYL`AUpz#!sxlOv#rt2+&EQb2`Io76 zSsf{!9~l+Bnqk!kH9bB}Ny{4%e&f(=vp(Ma{<>1<2>aDOxC}s;NJtJ3@2r6XI2nz# zrLhGg(R=qE=wfl_>n(M*pSnDkP8eaLKoiKmAgb2xo}AX<#UO*Usv-Z~`Q~M5Y3X)! z?utu<*W9E9G~ooGW~AtO#8s^~#{+m3N*WN=c!vnFjon?b$;rbFGyDT#2~?@TLAP1f z!NbGDJto$-v5B7e!sysh1W0Yp=PKI|oDP>uX|BNhIqIJ^x&=*5*dm;EJXvRJT?snZ zXm;Tv1HVOJpAS4QSy@>F3~)usm8x|*#ARe?M`DO&_M$@lg$xY98NE&-2zY%Wi;F24 zbXp}h?>)rD#X|$h?JWbswuy7|x_@tN3BbU@J_iWCc>^2)$D0TaqC5{%q~r6MoUZ)d zkPlJK_?k_@zI7#8xi7$^Vm53D`sSB08p#TtIkgutH8mUszURQC(8AM6S%v?K=8Ud;l3wR4 zBE$z60CEAvkG{z?IyGf5kxXBqtG7%S-boFYn3Qt8@g>_V*EbfQ(`NYIpx5)%`5?u*gFMCnQlNJf)rf4aZjVOp=ZJ7=wh_${9vQ2w^$kXX#QYV8oG)I)zC>Fanl?63lF7$WvbImbnYonc^yybR3PM} zIKT-OL=553(e2#!_DIxAMsgKy z)GY88fgYZel+;X9%ZI-Z{*=nm~%*H0lMUcmMA=fe$hs`4PWw^x& zTqb|?kUryi4^gi+tJ-j5zXy#OP2fAXf6zv5Q3N9+V~nx_pS-hk5;C>g%8zDYkLP!! 
z0ev?&po-7z=RB_WF0L;O2!0F*?q0i{UCPPIqD4cF^m}ccw~r*r{4A|5Z)apx;<_OUG9I6~1x$u#O`pt&3ZN$!`qTX^U3hMNSn3Nkn zhg%13_ZyIWNyKiPfeq*Oa42u{Y0a|z?XcW9?YKQp)}SG?W_pH&g-zX&0A7^+^_$qm zMjPhQa;j(nb9+Yz<}emY5YS8{0KkDl#M64{nL-iP*~wR5U;kfHgDQB#bTpH7bicfn z`>Na06asx$0s-Yf78Ut~7V=Jkc zun;=EVU$6-O5vBQo9Ph%XD)DZiXSbKcsuUPhP($Z9;d?*EnV0iPKpOij(ndK^+$p& z4y)`}t?Ta?+vPoVV#cT^t&$ANZgjwz^>94LDd+r^Uv{0GgleQYCce4J4-_XvJ{*qL zy1Vw^DhNdXqWNqtcYh!t=zMAT877{8oSqG5Sz-J_NyV%or3Wx;o0NbHN=i08 z69@Pn=$**x%Z=WzE&vB+hIxTaP0Q|MlM4LmLUUuKnG1>)G_7D9`FaNGfylo zE)F!032O?*1e>Y_dw#j+aXH(F0~<&DNI8hs7gHJ?udg4P7@x4~*z|h6?4EjB3G7k; zOe7qPNQ4O-I*H{!EB_Gvf00lnX+H-xvd_quGsdDb9imgLjTCGsh=%I&yVB03$QSk- z`hzVK%dYr11^jE}G7AM2=pC{1kPzw~SB zh57Z2;;^>x&k=2Py0!5L3h+Ru0o~3e!=>3|zYOYdEFA2cjoX>|xh;9C)APbpOBs${ z2RoUS$?GUn+#Fd{q5-COjw25%b-JXvxc>>Sr>EAUsrF2jpFa~4GK{W`_evTq@RIc? za0(jsQTFYpLy9dLo3{>^sHK9~R;OZeaLR$(Nih+qKXB;=B=8U@t9^ZYnt(tu$2&c= zzNGgslTdQLhZG4dtdI?QkAx{siwIWph9Q?Pk57)PpsV(8p6p}_t&06;gA4N^{;UTd|B?*al*0aKhFaELdq@KwyX=UO$&|6W`0@h+@HZK6M7b~KbMMiav$%AAH04GpY9b=vG$p%u< z|1x2XKbPI_Clw}WN#ZkoQAMwxM|ai>aCJ`#pO2p+0POZo zb#8REJO8H|E)%X~oYwaN{_)cD!D*~s9cfgR68U!Npz zIx3%9MMW(>IXSqqw~xfeuyO>x@p#tBt~NJV`M~pvd*d2&x&QkB+es?kUJqyYSl5XG znI_iMEWL-D7bIWY7_hjRI%`-?VlJsg;j8xMVK`Hf!su!eU8nVGQg#P-J zvSC9hpI{)vzHr(QR>DGBAi;?KF9>mTWHs9slyP-!c?V$R6r@Cq-QV8>9BwNSuEGJ# zKdb4QckX@v)=8XnE(|I7-YBns(H}k;7|U1^Qf&eyNvaw>Rxvynm^`n@Er=@rL4{IV zSWvdFGU^xDbva)ftKRGY;s{OE1<$A5g#Yei-$x%5TKlHQIzkdHLV7x%LV0@SS8f87 zXg<3@0FXvT(J$5E;^7@8<#c}6u%H}hk$Fss9sr;&{EPPJ9U%j$&atFtC1uCj) z)p&Ea9$o)ApojHym%Vj>86VAL9|1oEq!GY<0l4JkL z9#vir0BL)`QN1!HfWqY1|73Dqlh~xtR=W@WB3@wn2`5!WB)*xg(tT3)e%d-C_JD+B7!)i>;R406sF7Tb?@1y;HJ2T|gMyk~aKXw^`P)uX%< zo;Y*R?$(@nV#~HiM)IL)gtyf(6ofEPc2l#`nnX0}>Goyh4n`qE0F~_%@J}uSqna>c z7-JGkRrM#`7C9UQV7&P5+e(hl&I&I)bT(29EyU`$TItsw@#@ssm=W~TMGNNGF)IWb zd&r4i^-$Ci8{hB!X8=31(Yb1D3kPfh;N)T=5LRK;tu%Y{va@I4e#NWHat#O&(e8Fi z%j$56&&ZHrOGe6^Q^)cSx;{8Oy!m@VW!||eX=PPWu{m2K9W+?n#~RAj1BC4cH2 zhxSoP1uAy>sPJ%vXhEmUeN>`?#zrh9tgm=@g+*;&g|Q^Bi%BL}JRdGd1$*?TO2a@> 
zi#AmzfWn6g5XP4c=p$*dVRzh_6$TRC($Y$O8y6JcUIi|!=G&G~da`Efa5ryV2M33h zZX;fWY#;x9RQ^y-0o0p$CXCcmo@f5>h!Rf1J<(_RmqV&&rFH17f1UE85q0D(nN_P4 zvVo>;tn8uPbm7-hQr3h@Uuli0C@qKt^uCp3yJc>AK2*NHL2wbG07ExGYL}>*sHFYl zyN4#8d#aS9y#>O*CdT~CIT z2uN7e|JJlY5=*R{Q`MVoZ`m@20WL8)2|!-DBm^kyIVEO=RLm^P^L&6Vb3dVFT)ov+ z(h&X=Zzx$2HgUceCqYb#6S2AjzH_fJ|xd! z!H9h7$&o2hb!wCY5MI=tl$62glBVv*@UnF^yuT3ONGvQY9ySU~%` zS-i9Rnl&VFKwTu0!T9s%&)Z4Qr*b<_$MN}zfw_5TS=kBKia7@@7d?L#5EcQ_QGV)h z%Gk_=B6b;Sb9U;Xj;PyvUvBHd284!ObDoJsfs(v@zn6%Oa`@9Rw>JH#C8v9jyiN{@D}UfWx5<`yI#nv6b$gYmf}_9Ftb zIew||VC02qUn??}BH*B(v;l2R+XIflc!?ou+$+<{yvUPbhQ8g_n69k)axe8u{}v+3 zvm%oyHSYIiXBk$9wyVOz9S2S1oj`2$RT(x^dnu~tKMFZ~wlFa&XS>sPXrCKi!O_fXtF&8w=m+~dO7@i@MEA?DK)IL3^9^*3cfp_RV zxAqg?EBs>eySx(JaiGENk^g;$eBJ)u75>WqNcC<_e3<>{{WH`m+7-MJBvn6IN3Y3sXkz~6N}AZlEm}{`C59Nf@j#{JoU~W*2vf<-YUQqHz$XTAfnp6OOM|z z^*TYrKPuWDn0A!%yAd|d(0(Y!lnXM-*|>m^3M7JtQlDUnb+V-v)`aqK?AX)hic zlglBxVnTM*K+i%3Ta872uB}&}5kaG;{J=~9&V*NN@R@G2x(&PfrV8v$4lD zk1z91dE`q178}_BYrCqcH36s;DXHs9cVKSEyu&*qc}WdW94#XQAKtG(Nm6eRjNQnNO-+Ca zFHvCwa2leS^ys1HW)~Z_jarcal#+z^*GFaL2FY|~)uhwq(^l8Vt@wuws|Z|FdLjX@ zW5`E?%~-v-LxX!;6V<7K)4N*sn#{vKdDA~DW;X_yUcmHocAtJMP}zZ*17Ng_otcZN z_v|cl2!XvWArnft_rU%-h|%ZwlveY`7e3P?)933L-~R>{gy~GLVSQd(dd8N!KiD%E zYj0Ak){=Z5TeeNj?d-~2C&tUMLVt2607Hw$g(W~S?9Us3bwADvDXlD*9c^OXKX~qW zZ}H7>R#76R|%JP5)lzxk2z%nd4Wv zJXnlg_oIxi_rV{x6c)#}>;w=HavWv*CIjh%t=~uqI$tjtTUE6?Tp>@_;$h;vAoHZP z{s$`pQA8)4V~__CV$xjCW7g$@OUD!?bQ`}70wUtr=%f$^MFq$46BnSC-rl@6E}En? 
zf(>R@?L+X9Ia4!J8!K3KIy|BhVluVscG+cQ(KO=VgT7BrY2H?IVF8m`Gz|2@aVBpc zV8CD%kY=*5u*rIQ!a6UXnp<`bUA|Rp8vxh_UNc}wJ!)+k(0su3kF>P(-oeLrS$X-n zyNM5o*&%*Jt`zG5!iYBf9GZ;?Avk0kN!r?4D%DzH}mGnL;>FMIZ>7A z8~>_zkQjQ(de`pSN27m;YAy^^0(RM7f07+F0u>XSA8{H&+lyBSK?(x+J9#IXnethw zm@Lc!eK9_FM%|yAm%tp8ev+pn_iD2@evF{c|8;OC?ohs87*|9KAwsqrWE+fqsAQKVOIasm@O%9Jf%m%J_d3rx&$;i< z{bIofOmldaB8p(59|Jz>ucf8g*7(>zHjJ<-_IrynP?adIa`W}Sp!y~n5O8=~Dcl>B z$YOvPjL!hvirX^nXF(j|G8Fn$Q>a5TLIbGqQ?-^3Z>#EavlcJ(k|JV;hRjcIYy$N0 zy&|b)TEBQJCpER%zfX+~=8w5uU?7dBB1xbldr)v4b*g z3xwmiF5a^0d2`MXCiz^)wp5EdlY4XZU^WHpSy~+MrbM^5);uF@cH%1HdyTe|CVuIQZ0w}{yD*p*}D8t<3MaEF|E8ZsHSe4UvkMzY87@fhh+ zw+n`=EpbsNnS(*jJMrNzeo(&G&K^u^U$mJIn5+8V~gUd9k(@)NtG)x@!g z^kQzfaq00nLu3I7-(1Ou&KCRfSn%^?Ziutw@@HX#PDw*84t*qTeN9AM`QfY#R+wzV zW4v}C-6x}PGn?ZbPV|(M$6gzhE#=&w3W-xZ1XqOdpIh~pb7j}!{(~4QuNofWv$0Ou zD-y+1Iy8IF&8s60)|8f)+y105_QgL#;GGtC<@y5zdU~}8gUO?#Kiqwn35o3=504ji0s|>i+5vNQLc`Tg8p)k%&76I= zLP%Fn*UKSE=w!no8isd|G%lV92k!$_1I+QAG9{_(>Bg?}H)VR#zvCjEq2sQ!gC}t=-+-DftFu^Yr;b z4DfXa0&`LaosBhUqIgjgONYMi9egmkqveCaz$i_breTUD zD0!>qJi_n9a)cvSDc!S{_YCx46`Y!E} z49Ap^c+TcOZ^j0xo5%5auZNi`*?x#3D+;>mIYD2(PXPGe z@Ce)-&2wBVsU`SB|8Q$f#5^^f{C?(j_0oUt^Y{_`GLc6)p0~*Ay9^M!mp!Ly3ct$U z?D|l>z$eJBXTlw{I@0naiiP2cfJ%^zvTS_iA}}HR7H80a{fLzwuJ_Fdjn7OaaweHM(kyGbX#JvNTz69Ae zrI24goNL;Qd{PJ;(>oFg3Kx?GhC&j+=)Pl-PNJ&Y7u4K>YxCnyjn~Wh1oo zvZFQp`#YCCqVSnB$$daVY6K$;^z!)9Qk2L~?0tv|O#&=6sRAm&4`=^D5Ss*!^gX_b z+pzO40Rh}592L+1HXyQq`E$`sG3b|@_9nsj%&}u~ji`CLKWX$~*7OGiq)M%0V6ehG ze=ZfT^fP@O$!#JwoZkFUiQw5t>B1_`7GnI0o%xYb!C9wz&9c*DDT{e-_a#b#SIE$8 zr&fx9RPndt>Q#L{o@`&+fxv*gL(wP(Y6-v&EOhnAngr|wXw{?i7%cX_fuSMTsR3db zMtsl^vM$WSGY2|fm5qqf?ixbMj?{Xj4-Zf%l#uB*-r|xHtJ>{F795y^cXEuL>!@CN zEiaI>Trk9Y{`?x9H&ibYDx&U2I&!?)8{FITekT z;pCSANq_y`kus`QZg#fJjrvi}zIxy6V)K{FRD_9mEx(i@Ib>9BuWw*L#slCixhBInG|@=u1iwb+uqEBuxg5i>L zWYK5*u-*%f@e#sC!q1K}q})~v;VX>yxfqvfF)3H#mt1&5U9J7;lL)xI%oirZoY?pm1dS6QE~Xe~7$ zCmZc_(XY9myt2Uo51=tY2%LXEa~n1`Mco8so#NX}fle zn2ZdEdP&1)i61>AUSmy@b0s6TEuDJC0E>3so?mfao)XtMxbk_gVlwZ|8^@=*CMfy! 
z-d-u^0SOz=dSIva2j+~{`}*$=Lg|~~u2zL%X7U!p>H6tq z`VD>`c&k@ZTpmZQMLK&ruX0jl0JI(?or%Xg+yCn_Jfc8HE9y-w(kYozW>$8HudA-O z%*^`o+ex^gC6~y|z*$yXQe{#+14A(}VXT4@mWqM?lRy=o| zRbf`HDD>_UNeCPFK9SJ=ku^GsG!d3{T^R}>#4x?RtW!N^(pT0SvpCG&*3D1d3oVu! zX`LRlKFf`&mMcm$FXBgV$GgkoiCrO%U$0t8oIaWB=%i-2DPxj30MI__?%09`S@V#= zZ9+r4Nl@t0sm-IziPE|GtwnM0{ah9tic7m@XGH@8X{6oumq!v;X-F07kd&LSAU!x7 z%G(0#=<1jW{!IJ$DvbV4^iUgMi{rr&UyjBRg_V`b^vH**klUB?ohkqD z*~O)pVA(WuyxoJK&8>x8jJx}b#S)MCg@w?Cis0hX(wW6R4exD7JBCjM>a0jA@S%W+ zG`cOPsBzohOzx1rzv{p0qbP=AI?{7h_vqCKfhAgqeDH%rBJTMk?_*9g)#XF_p4#fy zrZ}r?N?*pboqi4t_}Xi8;RR$~rM{N{Ejx)wEjM-#V?{ccx18wg8nmGCdA9@u9K`gM zRM5PI4A&XI2lpRmtcPQSwX9=_{U_v})Y0Q5AdzujPQydsxSp5)?fmoo#M%Um9MDhE z&d!mR;TxwF`F#i;n;2|D!X56X;9d=A8Gzh$f_U4<;MMnv7>5W!z)3*)ScOju7oZD( zZj;kpULJkCJ)?z6ey#R#gjkyq|L)W^FI9zJFkz3bUuZP?MGOLV4yihv=>cI-ubNY> z_-NlZvfby*$IQ6TNLv5LZ;rAM6D7B0JOL@Pyn`?S15)R}))wpj{(eAU(BPzo;}b`e z7z?|=%<^a24S9KaMD`0}`E+XKGGgDuh#HnrFllBH*!J}#|HcJu-LIHrp-2hCliSDG z$6+2TtB#T~h8^bLU~=X8H^mDYB~1l2tzi7TOjrad_8Sct^L+^VSgF zYB6S(*gM=O+VTi43qh0>E2wWlLw>Y&>{%)p!J187j~0J5XpjC@Fu)qF=o;fiM{Lhn zW&Hf1?@a!O#Tp!W-2%S_jEwINWPZ&+O@uFd8yH|aDU@m$n?WZ#CN;}k=Q2vohuYD0 zTK{!>byP|Gz-OT>ubO0Ve&cjDqwz7Cp`_v@D*gm!qFF|=LVsbqUk3C4%DvzsCZlI$ z=rt`~UQIvHV`M*O*W{y4O5J=HG>ST%a5{7?kL3t@O-@!}J zg!fp0{fu{giDJ3;aST}NIwn&)QRIGkfqHUJXYB1t>nd#yCeugi$qVl7O+T53hPIdsSb8@HM}M z++O(xU$sZ-4!7(9 zRnXnn1C1^|VIp*0{x|sbPVME-=VD9ZF~9fZy3Y8z>b_-o;a3^Edn}?p59QktQ$8&-%*KdF3{KfY zGNpuw5#hHd^>-$KDE9)!Ue=s($(I35ZCxj^iM(00lI{s|QAg?L&DZ!*a0)(PlR-sfFQti*`&SYD>#8F1U*In=^POWti Zk?dUud1xPz1w!o%2)MCM zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3&uaxE)whW|5)3;}Ty3`ZcS$_z66JnHRp{Z5if zDt2w#yS-L7NUz=tPWM0l{iT27qt%*I>pI68BlFQ#TgCWc=3gJbzDsW3*YD$&pWn$} zucy*}eo?qo`Cfnhcif-X8`t|EZ_uCD%l_;2r03^8^XEoCzxeHjHJ%)+@SjEg+^Cjc z*X8H`Y5gqK#-FzH^KY)af8O-#SAY9|Q;Ol8`_+=lty1&%wfuTfWfgT*<9mGN|0S+u ze*W#c{%ifx@1^^P0+#vD^~>ky74?sD{JPO^+ZunX)6aG3|6ISXN`GFFl%JQBKYkG( zF8Sp@fBdO_@A-W1{@R_jkto;r+0<`a{C?tv6t>mxZG%tcU*~82J{O-eTRwEn+3HU8 z;mDFgHJ>zbF?n*o??(c^dnu>X@7kuSLG+ 
zu6P~~n>^WtR~F%B^!Trz>F*c*ub+Rp(R%@NktpVe^sm7s@WnEsbMci|(GYn*HSJ%0 zCH?xA|Naq;t#q(|Wu9!%^8R&;-Q^E#)vwPr_ryzE{xvV6{&NFL#Irk#kxB>hklJNP z)>2f2I!^jyti04#HXV?jdT~}+pp-##W{GF>9+mI4rI%lqp%)s_>!?a=7!;b-T5p4S zn?T4?uf6r&N7p|4>Ne_VqmMCk1W%b(ZQ6CH>eQt+>uj^nF?G(lrmeCnG}hA^OV?a$ znek9-#=VT!8y9xjb+_I3*t+Ll+YazK?esH_o_W?WFJ2uANIqPB`f|H+fzoZa-*NTM zyRJR3_Udb|zwz|Vx1RmpwV$s3wQGOx+<)IS`RQ8u8Wpe8Kf12tr4-!y-IedxRIX?onIUtaJ~1sT8Jd z{ww{?t!JgP(Y#)_UDI4x{-V08_Q3?7@#bLPmqloQdyWT{I^;^*2jI9~%edX&UP=8~ zEZ3WYwkO%0cf*%nhS5KpgZ05l+^)nEPCoSswK)8v53^Ioebdo9cCWY8vG?G6d6VUI z=ylQe_Uh4<)oh)|-9|5g2Vi$qb{PYhE$QIMX0r<$@BPGd(%JGr(YAHBj}3Y%<%YBM zg?ik*WYI<5bY7<~XH2nrc9NsIaW+*KV*$suqPH^#Y_p%~_DgTmigXH2qtZ>UR-j&B@2BQ~v*zJL9*Lu3vr~5LrUTB$T?XAU+^X+;b zvbK*i1oiDLE7R87SsaoGFH;=cT@mf&@l1#G(LD!$zyhcn+OEY|8g zWsSSMO*vk)VQqBK`wYScLl5@TUW_JpnB{02FAs+9F=BSCpWq(6Q)f~S)ymhJ54#Lg zVNL-aoVMy%yGmMgXi#Ln8zOFEugacYD1$dj<2rlna>H>827ve0@OG{VQ-XNjJgaNk z%(R%Zb*RG)C;=1oUfY5@=I;CobgJ?cH#Wk&uQGD*hsJP%cM#YU>#*Z*O}0zX%PZAr zi@muTCDX%L3@a&~O`0=1=P)z=Y5_9s z5*{+l1qZR#^nlhuv_N+`YcLRvNd;XD)CnhewJ3MV>t=~wH@2k9Ztimxr3Zb9yCp~H zAV%dnW1Ja3!=o{#t8|YUeY|zU>@S=yEYS|C!o3*RO&1H;Z_-W*O%i^#4-tUCd@SU6 z1Bh#EOBdR`wXjVYqvb`xdE%nL>mW!s> zH+{VL(iFl1b^~Nz0Yv6I@yx<}gT5wZ(iX4}@L0}-RR*`E$1gSz$kU61~5r!juQ?Z-~k1Q#hmy zfmYg7Lkpq51aTle$kSYu0jd->^y+) z4p~R+?ocjMee|4Z*>4k{b*f^G;Yq>$?cMOwL_`*Xw9k&}V==U2%V7jTfbA@ONFj#ef|nYO za4~}Wegkt8+Uof-dKFgc=nU5lt!@tGU0JQ^-Ycfpp>|9h1=x``oAVH+X%L5ypJZHx za>XF;u=m*yWU<#HqC&lLQP1=4Wu$hi*28tL3CoeP8FVxP2e9`^Bm?S&eu5W#5l(j@ zoCtgSCYAA|%=mV<4FWRTU`<$IsD9_@CngFnyAxg*HS^CYOi1#eT@z})RLI~W3HM5Y zHWM($txz6z6vSw-*lbYc!Ekx101o=?z)5Vn!!e?pq5a$qi)IyfKGw_-EeC~6{=)Q7 zdT4?UKI@b%WsK2WU!!vZDg=OGHac&Un-nk`z z3g=+I05$Xy#+t5RbxvjXie?cxa4H6o92;e7y9_;5WOd*i?f{CbC8e?(r-qNEnM5n5 z7gHS`Pa|0{(ikeRP8M><9uSa?aN$71McZ0v{HT`32@i@TSshu57SyoY2r9D;z2f#l zk{Yj#y__*f5X=!0j$gG9y)(^fVpOU-<>m1RDrAN6gZ}P;Y?&t1A3+}g47wN@wJ8-B zNshDXc<)ZJCx#~x->846n03(ZkPgE9c_Qos(TkxYUmd)>#ZW~SW9UX?84L~TNnSh; zUAsUBSIq#uP1R^Vf{S!GDuG=fxkMb7k)e5pkA2({1Li0N6H5-nTGx}fU&O{9p0mL{ 
zwm~6~!;d|MO~^z5yDruVHDaNmv2aaD47m@o4+{xVB-K!FT%}V2CHulA%XXM-2=3r{ z9NRNATi7&+c3~OGA^4LGV&CZmoL)R3g$2tw*XAoJ|3dH7dWtfFbePM`Ka)>iAIy%^ z>G5woYvVILc+EOrI*}hBKWI_HN%yeP)cw3yGm$ym2^+{^RwrIB$kIS1_6hgJ2qB7L zJW|+P!Z}=krxau#Z6Q=SG>%{2nu#J|Fj`owx3OePA{()7OffbX(l-)7V)QPMfZaDt zWMtJE(va(+L5$01y9e-MYY}xw3J`0IInIE^;&w^H+0&u*>=8dQ5F8i7Y9BTLc*q>Y z&rZt+V@YlUVg?e7(eNApV%FmdidB@i8nh={MVDDz3efWivN`H(FMeaOX)fq1|W zUt`G8!~@D=8ifUdb0ygIa6Bjxa2)-rE{24Tk-vHfUdUa?93kgVMl4o)3@ue>H>SmL z%9|0k66b#0yH7+9ht^<>RT#a=QrZn@b?BonZZV4OQpmKC%1a*Tg}!OARS3*>@PN7-OPKm*9(U8oOJ zPXeWI5H%SE00T_3IEf&11&{V94FW;#u*yV%m`Y%PMQ)2@lTK2~4@-@L;~li);K|(h zp*W<#9s^r=>jWc^M@h_>7a;pE@zE%w=s^TD5y`-DvvBfriK~a2@{1K3hOL3X=^m!U z7M25d*&~;~K^Cy`$pAj_8nUsq+6q4W3g4I?>7#2&v;uD33#3Bkt`a-4%ff&k0u$We`1{MTn%n@u%O2Gqc zRv9OjUpY^-pG;N<@^;w^7NU|AyJ1cZo-Je~#evBo`m&hCKzvS2&F_D7BYYJTL@ljEar|}z*xZTaPAAKkRU+zf_1v*TT2*ryZ{2RPz_@2iS;QF zg5V13LKi2cKqKaxRk#~kL_aA+$gm9j9x->OLn4kAz|j$5F5=}VE&K|2W}IF!?0Lnm ziS-MegR%I6CXXpuf9c=#6lXPGsBH&t8ibo8*KHDeBfWJ6`c)$AuN-C~juIJY20Zo- z71c-pr{etm(@&&SgZ(3M;%meMgYb!=FH$;niK5SvB6asTDCOSJAbuAbSYxwmDmTWtYl*m2+%O=q$#PmvzwY>EGahN+Z z5jZ~L^1@+5RwPRP$2QPwnhFU-M2QedawOWpg;%P?*#YUdllx;H9y(4O(C|7k>;peB z;5$IWveVV;4NE9gjcbBiEEzW8jg3VVnw2^|4T9BguO94^6u&?1u3lSQyD&|yU&EF#P2B|ID|_di&)N?3v3VQ)NE3^SIAKKM6G z7~+G)k#uI96F@&b?7zmrDIz^#L_%WeU>qz2WA&WxcS;30L%n;&N$2>4V!x&-Fdg=Y zB8A|gJ}Tjov0CwmW=8xzCn0Mft#Jq3M&DRgo46Hjg6AxO! 
ziWP)w!Ystw>2%s+sOSe)`2@8nJyNHTnHtTV*I{}mB)gegoXv?sAcXZr)WId@hil5W z6yE*`cg~yTU=A1?)Eu$DiS^*X;Wo&$E(tQDBpH^W;F#GsKwitJ+;9z-g|u@yI& zv!UR<1wd+$-JQUiZ$P9*;)%anK(^!Z@`)~ATo4VQP$SMR4nymgl2{ymufAR@>=et{ z1J$$dP*(R52kkXT6m&f1+sGA>T=0jBR*##)&c&Y#vDYzKar$9#(r70R6D*p{?g*n9 zO@bHW2&s}@7eZ5-!#@r1#(6eiu3EX>4Tx0C=2zkv&MmKpe$iQ>9fZf@TnL$WWauh>D1lR-p(LLaorMgZbzeG-*gu zTpR`0f`cE6RRb_Q zA;$tL&>*>fus`_St(Bh`_mU!Up!3CXK8AtNF3_ks&iAq7G){ovGjOH1{FOQ|{YiSQ zrA3c`-fiIGx~0i`z~v4w@MK7)dd(G>uxsTHaAWdB*Z-9eC zV5C6V>pt%exA*q%nPz`Kfyr{M6dI6&0{y7c};+c`b?2G-G0mnxZKhO*Dy#$*U$tqtR&WQDcm~AVw@i5U>CW(u;K1 zh3z}#`(xG>5s;ANd%wbY=6P6VZkf6FH>aOLf*?RW;yB*SaXkOn$tMXTBxNZ0?ZS)TiDgF@v5Wt@EDc>FY<91G|zE%%H(?}J-%gd?g;QRlb zH@f1dnEkQ0g1wv%ESqL0WB%u?Aj|O+SM13stOyHqKfGdk8>Ie0m?D5h-<{a-a~#kl z06_DcQUIx>@djDNmHRH-$}FlZSL@6cR;7@J2f9xg6xPnuNkWsYzDsqTVbz{kL4d3W zRk11A@#@zC)15u~jpJy<54F_IA<$cgLLU z5`n-Sd|g^G!io5MlUDArm@GhG5QGvFyZ^$?{lCTMpIGVZ=Fo(Npf#B96_kg~-e5TONvm%(d-pp>E~iC@2RJHae_D6s z__0g214Cho^l?%7xjWR<=}W5W+ymX)F#5j+rXOrS%Uf8mssAho1h!X5o5G?1f>O#f z594pvmab}fBe-_VOGIR`J@$t&qJI~yBgi%45Iikl;?@fL> zM+Ioc8_MxW=_Xlo5XV&DY;IC(Dll@Z+= zg(?5=h0~kN*$=9MJe>f*!A@#}I{?Jp$^H7L7?9H9U0u1Av+>vXl&uRQ+Ifl%QtlV` zn!ov>+5o`wN%xDJg6YgAy=>4{o=wX8-92tWWwg$w{9 zf;_r)@Og1~_gSMOUF_vR62+y}^S}8`Y%HZIkkEipke5?bM4-}MJ~^rj08j)5lE5T^ z5E3PTR%ZeNMTG@?^2)H6Ms^SC;N$9`X!K@=^#~Rn7nK@)k2&8Bd3XEOJNXUTJkzzE zqnrJ-!4Yqa>)koP9RvXoEdAk(c=bV5?cjH}TUZ_dJe-tc`*w75RxO^`uc^oH6YA^l z=Q?9(x7lNQxjHBSBf#k6-LU{Lws!}>803<5N3MvCcb~rrG6n?U`N81`wPCU`Zfdk} z^+82b#NQ!wP~yEJiX=*^Yxl?AzL#GP7y$^QdW8VMya|2gO|bHeZ|9Gd4qpZodGdP3 zV*?n>S2n!v?r7K41f|t=fD!1;Ti%-3-pA#kBmjh9&-oOZA|L86+|VAuE-Lw@Z|98q z;3u8dpw<~LBxhZ?lGeGC-;Q^m?a|&_Y#HL?Ty^v#+xU_5HYH{h#oo-ZBI#U;TB8Sm zd6W8Wd~@7Wpa0ECecznaSNy%@t%;sfJ};@ROU*3>fN_000ti6hxZ8Ppqgl#OrxP-O zB*1KWaagxDL^kJ;dQ6u8bwcp<;mcniiUUA^05AeGyKm)m7kewWTX*Qn(QE0Y)mm?7 zrKmNU>v42(1g*i`%nUt3ynl^L1sT)rjg9>~`+fELs4ju-A}gBl!A`JAbA$QPmIW?h z0RzJV^7k$I>R{ZO{g;5_K`u$nEsdPF>G^44WLff7~rN9)DM_H=Rq&47-~C+N6G) 
z7LF$g405SLE)kpQFeV?mMV*wg&u5MJeAb9Pv2n@U+x~{B3zp@qp*~=v)@$|dZ8O@P%ze~?31_A>{RwbLjHV*LFtl@jlr?4z{ zCNXnlWc$h*Eg*Pj>VP&lwi2e-(h5G=ezqxJ6-5#O?haQryy0RmZ(!+fTrlq0ft?Wo ziXiGc_7`tDkyBU!6uES2)RJifm2wF+5c_t|8S&_5i%?rJYz z{mQVK6Dw^Ro*;liw2JD{tWiBj^bGbH_X)@)Won&;^j7`trOrQ5l-2A)pJo*jJ? zGmE}Ikx*2j0R&#I_8ooQ+Thrv3Dq(ojv75I9G_WK>G%8>>3QW1dPo^cDP!!UOoQN4 zT5g%ht7Z)A>Zp{F1b&(^);N{bXa!#Ib+zC1-jw)_bHHE*0w-n^1Ax6kqLNz=dxb<4 z1>YRGb}PHYY_X~qH$+wkKob}ts8H)fhttxj1D#cJC#6hAKNe;i85shCP@~pvK6c%9 zvQndar-cQ2_`9lZ z+B|=D$&=Xedof_3e zE~OJPit3N>>Zl$Ye~UL+xDS3j`{~|`ZVrlawa&p#n!j%eM##-UQCX|cEvr$D`#ApF zIXyx=K@fnz9baBIHL43q;+t6|4Q^@Cv;7utItd5>AF zM7oR?7R(mjWDx{`<@j4UW%@e9AGdQ3oKMNhF9Q?n<7)reD~)gVAL#{vfYqCfdXqtK zG8oL9i3O7-IwD~6qVWJQqG!mMUhTnXHtCJKP9$tTbh)s&3QQI-n8Xn}`nrBQf2@q5 z1QV;(8M6y3>gx0bd1WRuOAvVB#J*r*!OWU0oRlV`BZ3BY4FbJUqts&k;HV;dHmV}`|*wO$if*|bjwg*0vZA`d64CnzQDEpumjJ-&gp9Hpt%c*KMt4!l9?8KYSJE^(BSvUa|4w!?rl>-_<8<;LcMCKI zlc$5-g#MvZqQjmv^&X+zY}p%gGbTBws8UT+L>E8z7e{pK`p|VRQfhAL?wFLE2Q>(R zn}eczh}ZBQ!5&UZTLZnxvgP>ob4gixlR3!UadLE*k&*4iOB$_l>&b-tvYIbukCaPj z07$r7xG(l*QDq&0ke7?HXM3-q-GiM~kCga*F)7nm?MF0u{oA=cJ9xKYvGp)*hAyV$ z41RO7Etg8pPXmKibfzsQuCJImq*2yLVBC!U_kWT33A|wS zH~md0g|Fv~YJ**Cg#Pql-S(bO5&eij;8iQdFT?`O*+ODf1tL>6f!d?v1@wPnHv6#@&-t zTBB|G&~{VXg80RMzyXEd`t=4?#MDYeB*D{J+y zjqdq(y#Up91`Ep>Ev(L9q9{@+W27{tl1u(g4a{ z@9yBBkTyT>@%X!Mtvd=dSz4{#`{C687)(8#m9LME+;TL&b>a0!>C?Rzx12~QEU6JQ z;y@4(O$B*4CT^N*uaN%5!F5LSgyp*{>-4inMa~)5>+feCzg$hb{{Se$-O+AhztG&0 z>JdGI|E|Gr=alw&^LvxY^2qm05ia(Mc3zI3zC2vaKz%CR%`acD@%V$P+I4eA5APY$ z!l~c>{?zu1~<)%}Mp{qGk~D6XtM@yA`O z1^U*K2Wdk430oXKbg0wbDaXi^ZsWU^G$Xe%_j!3)3NwQSmq7Nql))Mzi6*&nNv_&t^taW(o|D!;7d1i)LLU$kb8iKW24$mMq!1)#J2NttREd{Ku#)o(x!O^ zGh0(<_-y~h<5$yxCTD*3vx|dbL_G$1E4!qyTnzwS0^I%F9UeWsHSemC6}Vqinf0K` zWaa`q9KwS<#FP7CZvjbw(R^Ul%&71H0YGOk$0XhF=<8a)Z>`pNB`uHRg@_;zA6I(; zfFSVERWqKPCGlQSQH3VR(=qJf6k@(F0OUNVE~>0GTP*(Wj_te}sng6Xthk$Bu8=bQ zI{Qf;*8{KI$*a`p!h$>k>#gl$ylrH!_9J_>FIQ_b3M#qA+)_R+%H}c3^TOui*MTI! 
z%zpIp&=02%vR&%v_tyH9}lb2K8-ggyJhU59wdoHfra>`btGIdO^ZEsH!^Q5Y3 z_2ZZAx|(z!1W{-R5|zC2$TBf*^q6}Spt@Eki&_E#fZ4M7;~8&^>%|MgmgCpwZukw1 z7QhHNKC)}jfo0Fx;%%Sr`~8apmo#N+ATa;}j(cH3pWW|GX`Dk4*w@Wo{C#ZI%rPJS zbTa-fki^=9al^U?V~k3wYCVU32olCx5c9qO=+?29Edkb3L(oW_XYXLePHZ*O^%l ztTGt^aJ-kV>*+NwbqR16XSsg2aN6nvdD#zuB0&&<7p`r4vrniW035n>XWHrmyxs^H zgP9Ev5BT+y83FDN&}dBG)m||m%s)Cj;K?KCX{G4AAl%3(0ssd`#m7w-t1S5D*!=ZJ zt)VZD1BK+`&FrYfTg9t-lch|p13M`&l=Tf&NMlp7XRJB&Vi0z*k?Qe-b=!Cn4UElaN=YiGY+#h>eqcwt^)Cw$&qn;*GZS>2E>7X|OKw4h8m?ZAu zqy$kEa2yB%C?x6mWurgX`RHu|0Lr=oV!?ugofJqy#0PLZNNBazIB~_Ea<$gViDgXt zm)C7La<#$i>N-RBIUDz$PXQSNfI#2@!MMJm0DuqzJ1N-7#ALTCsd+<|Y_Cx30ApY| zpval#D5^)I_}gtww7EfaJ{L-U~M8l~!9fNB~+Xb9b-< zfHnJmpR#fft0?LT{2+NmPWYgpC$f@vxurFeR_u9DU3cbcI?oG_ z%u+#^IlTL{!C@RH$Rw1MCKEG?&L?F7LJ>h87uLQiWvEFj_nt}2Ja#U*aKShqR}~V6 z=J`dF2Q2zmze$k=Cp+2R&!=|wcYS}` z+1+Q8G&-e`Vf z{n6d$lB=rg@=9yonKIx^VkT?hDmA)A>y9q^=J?{N1LuzGB{I|>&)mo_tpR|@_TC#8 zjPrF<#ithzifF{1d+qBZEXyfm%x_=4I=E}l4<{35fBmb`%r#+A8b3j=I2YrMU`8QUw?gUuh#qO>+~j0Z+zs%q6mzUMsF5xRN8da|JAg7000xa zbk{6s+a#sgMuZ}*ij_N7yN_u`-{h>Pex{Qo$Kc-KI z!Ciw00>>m}aU2I0_SB~@Ix1yShWhaNXyAEZs68>s^+yi&cJ^~~=pN!VA~K{Yi;*G; zaqw|{J9@jQzBzJD%#@uszW4B+!Co$^UGGov_HYD%!*QvPb)<++{%-BOoSq%f+0$9s z@+GxbNI?LedlUzJrdvBf5Mq3t)V&u;R*{ygGv!c%g%j%euLv{oEb=+#DuHbzvwXzpSQ8tpk=@``U0( znO-_Aio_Uba_70^Cb?;0xhB7?1~qN+vmC#2R}9os>0Ip<5}I;%REUaQewn(dzPTke z0Dz>_n76lCoiYf4LJ9((*Pz$K-EtLGQiQ7OnltV16;=QMl2D$bS6J&q2n;d?0CGxd zS{H?~{l1x90t5!0?-u+huPVgbxu`-@($LK;2(5}8XFaF_1X6|)O~ad6B`X@t zsMZ=xmZ9%%=P!Jgkyi#J5$xq8p__AT-^wWg0s+9`*xTr2qV=qd0RorT_o+X8eqUF6 z(IcrfnhPs5q5?Cqf2St}EC>RS#OrGgy}oL{)ksuG0Z{gX>Ly{j>wNO^oiWV>2?YT- zp0Dq^>J;Frlrb`fVw5tIg-g1fXEL))V?0$<+XP;4@o=awaCu(vZi@GyMt(iHNuoX= z2;lL!xn{O-?oHH2x15-V)vBg??hY};(7Ga-26XCrhOt%;657wrfh4d-XY9RTv&q7O zTr#3(NT|Q-#gv>B(K>0GUj6!)GdJ!7O?K_zJt#b=phA5(F15ir_3h->x07F`M%QER zru?!Ru&}4DXNbGPa!+}?Y~S0G+ewrUEjf|A8r!=^oK8SCZOvr1Kw}#R0FJII+d7ME z!pB9`Busfh5ICL{TI@hIS-zS*N}R7(N1wRcdB9LnZ~YLn_SO2ONCjnTA2<7Up3Wje 
zo;SYFikZ)}#*0Q^5gEmmn)>G>0zE`jZe2WKOrMT_uj?)<+<)(>L}1ACNm)~(!vMf! zW|K3E0ib<=Ybypsk2`l}tT_l08hCza_jWWzBxDqtSq?aUeE&|OKWN?ID=C@9Vtbk- z#5$kc((3vzt$05@JNYGiHx~fP)!HMM(?mVDQloRUli2{_tde^=+2xnji27USK)2Mq z^1B7)RkeDzreEpa9ehLq*e%#AZr$uZ^}K$r_TuIpx|I6$oYD0!g?PJ&o|7q2U7qyL z3IhBM&UWDVoy(r<66h{EG8a$kAD5gBBys&-q2q*4Cid&(@9t1sr8$#uU!{=T*|xBM zXFrm_ERY}Xx{#b%JbQE`N#c~OlDyLD?-x&O5lk_*caW>K#!W~6=-STH$3=A~w`}gX zUeOVOG)mA>Q+hYQyrZ9Mo!(^Y^5xH$HXMn6wtuJI zp}q`7oV$?~pI#vHi@u%wW{vE*XwwO^g$tVT)!Nx3+WWYq<(H}J49jOcwr(yCik{uu zCEhE#d^^TP43eR)_Q@Ug4u$9}&p?uVUH z-a?p4X!4n^LBGV_0*ZX?^MfD=ltMcFnTYTpk0nzEuHJSQB-Ea>NqbHvh!Uk|7yrKw zQvf*f@e30_{0RuWa@#r4G6ORk8tgTH!k_q}jv!e5vU-!{-v=&08;??X~~rCbG|j`O?UIko?o#CcmXg1lE^PpHxE+)xV--L z-U~LF&6c>EIdO?%;V$YP>>1+i{Qbi5O-jxZAwfwcJC{y|rfQ)uLI@)Q z*tO6Y|d!8gu)65wv!SH z2|_5u-!0hNY16{-AwDjMBCwRkQX12wyGy~!Kqj|BjwjJdh>&39)GCJBtCG?p_U zp#NZ?2*0u zo!+6oKs?-cIa8xG0KmAup@j$DiwO1vY`sKC=n!uwQCUWfegjg>vusYULeP?|Hj|kx ztI@JNuaq**4QlxT0HC@~uhtqR6zS%uXmod*|Dc-D%vRUx7>aaK$)8M^s56*k4AsmR zewwoG(bFESw7O1jvUoZv>z8-iLWM0)pfOuGn#A&_Wye)&bQ~{8X;S%whe&HMX>>*z zL)ZUo8vs<->Mbl!Q$&OFMw6vnqhmQf@?$Vqa{mc^JZ3RbQDQ=IV&4(}B!hou*00000NkvXX Hu0mjfn<8~) literal 6501 zcmZ{pS2P@6)b@vAv?2cL=mgP%(Gn$?sG~&mnql-dn5Ywd2!e-R0&+plh+FGjQq>Q8h0DxQ#uA~b9;Njn&rHTJ{|L)P@ zabHN#aFhE=PxZeG&pvzVz6(I)srEvdXo-yC0T2Q^VetY07zosqp6g@g_7OzJ`u!Ic zYr`5r+xrD=XIWR~`cX8W9B9m}4!I*41zapm_(I@%|5fGCDC;X#0r0JWyBHVn1YJa$|E7VCN;rgIMAYAefe(R}}g^Yg1+f7DUCje5R8y+7Zb zX113iFdyX&IsPq|-n?)A_cZb6^VvSrKgUCt&`uctJ{jl_KCmA+iT^?m+hz<+cyyI( zkH_}M=GVsikf^`KKm0tNCS<6-bKU%OzII;OP@}497T-lI=s5dWZMF@OX{6iRz2Zh1 zx*FA_+GaLw2Bw4Y{|wm)L%m(aX(w4N z`Oso`;wLxjfaR>;L*{S&qlTPk$!C##2_QrNvB=6%uT5#T$NsGv0V?oUJ zt>v%}W#Oyr6FRu`#%acbZZF6O;%FwmHa(>Q1Frxv276Dw!AW|~ZoaY-Bh!+jEAh9T z8i_1NUdqSd7o`|O3H%zj zHB;c#>REYLF;8DRAps|M0lrW!0g-@Yf$@jE>5cB=`J12Gp@Xd7FO4o&gaemCcI4)) zo-PTN6a}~@$c9{>J8Dsv-dJR=Htf%E8$gdFy|)*om^P$X%vZf{lK6Hf4<0X|XyEq#~b*BRxFFO~2s+JdOSQT#g%W5=iRr9dOg97}eQ)As=Wc3CbhzjL9f zU47he$?q~R4-KZdZJaH2H5e@Z;?U4rAc+|qShE-#`&!mcyx{LR7mHjFwO&V~v#kU< 
z`naBH7JvyS4SDkuiif!S=Z!6i0*5^xgq_4r8dm!vIz;TlR2}5~fkiCOj`pkfkD99* zX?2eGl>Dv)lhh*~&#J2zaIP{;S`dY*1p&a?KhO#)aFHDAK$>4ZsOCtbGPpl9s`?#;sI99qvljJvf<%* z6I2wUaR3V|i)XBfD2Ht#HpXgCCVrs~uvkSs^;K-rL@G6?Kng7>vAjb=Rr!n~LT!u_sQAlbp`XpLY-&*cB z5->+{wwA>Aq2pbj*|kN(*p^j*-mag=6DmWMNI8zn8c0AG9x;@fA(B2+rfOzivMCo| zR-}PE1vV+Xn6N3|A-j~e9dg8ES@X5^^P0JzOv7HtuSfaMd_{wRQMr~O6)oU2{*NtU z!?Z)ooZx#wmN%<0$+1=0A#TjwwPL?k0)@6rIk_u4Y*)(ji?mXHrTmmIyn7IeANt~X zRZl&&DgTQV&Oq@XJrgg**tX8Bfu6a(<+A*zM%!G_OH0_xCFUo$nqWyxe@75Nu(l*{ThvB`gnvviEiY$?I6sj0c(B6k!ky)K$h{I4x6wRH`ag{BX$)?Zoql5>3)yeKM> zxGKdR($o)ad=7bjSUV_?eYd_e>8A?C% zY`7SQ2W2nqgPmm0G9CSyYu!ZW6&r6Xy{|f3>jL!r+frtxjFTRF#Wo;tphyTzV`Z|Q zmWMzNiB*4|^?7dE9&oI@PT2m`){SgCwOm0k3iyn@@;M59la8)T+sDASRXq+2pC@6X z<>h{EkSLl3AFt>*qrLn;aHhtPN55oJ!cpg%pJs+8yO3~{Mgl{>w@mf-8C#3@Pp|!K zllNH@TiZv|;!hA8ciCOz2%0}GOjS1Md&+lwGG6ueZ&V#{*2`Q+!v3i7I=>`OEX`9l zf6c}YK}pgm&V*!zCG;qVv|-LKsH0LVmf5%=u8$|^SS0;kY~wo561L$kJ50z0)*P$1 z?|PF!VEW+hyRso(-p#hF1C){fm6i@+TvP?U8@E(P^HoJ1G0oG<@{<8#QB}3}sow9{ zlo(1p!Y-KXKpckA7~;ZcTcntM=PQ*S>2pO(6-tZ~Q6Nb8;z(BDzXVO0NPr1eM+bT*z!}(7%G{bay=dX8D`wfeTmFjL8>>xuP|7% z&kASfqovJlJVBQ-Nhyd@A;V2ig+9{QxWT%=_L2l?e#*?zjRpd13p^DI;kZ^#DkrAq z*T1c+2K}_S4`c&qY3T?4N#F4*YRJ2MMojPFk!=$M41`ih>Ar^_E+H zO}J_E9mO6GC4Kk*ywCGxhETI!-eHyDUZEHhcn$H#9^q*rF=}vumD-EkRA`LW(CSG+ zA=qWunelq#;OmtVm+`(FumkpxNHfwx2s7}Ld_51kZrFg?*2$mJ%8;;1D1UPh5Yccl z#Rk!jLy~%ELBm`u-?#{W3-qk`)l*Dp^<7TVhX zpRr@UP>o5tA{q66G`n|7@p5Y=RG!x6G~QTm{HWZFfBa;?|U4Sji(gd3W}s=S(4Iw3TlpvO&MvsaQCx^>#v$ z1Q=J49y*K3lE(+_K2S=e<&_noE>N@s;ryX)iB=A#^7*-q=Q8n!Pc`a&!_99wO((Wj z2v_Au1TE4xYET*jkmD>t&%&NyxUR~iwKWPYK$!S8S@@-SMdMK7B6fmFZB$T#@;b1| zK`bTZFW&>t<8RdasYg7BYU_*bV`l_V`0Q%+AhA7>J3b!qYwH0U7}i&DY#$N^Vgx;s z=@QY$)9jURuMc!umR> z%(_mh@n^ONF~C{boV#IGT+|WGhi5;`H>kY?qt`!M2R?q(xEzMXn$o#xY=AAUELqOz zUqxyR?v7SI8T4I_E0~|*t+;=YOi!QHh|?$Ww$`+&-%1x4=ZqkfoE*_M;UXNZInSCK zs5Czv4la-SWPbFMl-}0b>ZLJ2Hzfr;<=>{5fj0^c5P;?gKte=GOoEOPG$lNJHd3_QP8czUHkKGyI<&}T)+Z&A{E&lAp@w!$# zWk>srd@F2qNP4A8MsbaH?Z#<0ol~d~@4q48vV%QG0V@U5&f}DIrbYSDKZYjCN$raR 
zQt2ri^upaqXE)VxGwU9asiUL_01G{H`U2<0AAcu=qnGU2?Dg(!l$z_XRV);sb$xFz z9A;3L89Jwv`<<_C*g>Dm1*yOijE+0qpz;v<<0q6QmVN8Iqy+0|>n*4^Z88>baaog< z$?oRfx}Sg7aX#FtJeB8}*ztWGuE@^BLSAa*_e0B_-aY)Kb83jJ9WJ2RR9m&q)=F~l z)zN5F=Enf+C(R3|E-wtdT$!UZ>YC!fSL(gixK<@+g~Q}c1k-qd3_Pi8rS?EoKE~`D zeLOqOHM;~64M})uehm9V$CP9H=mig3Gs6Esq41j|h9@ESAf?sRpEe_bk~SUKd^Fd; zFrqU@W^vm>7@GDy!wV3jvZ0^S#-4fJdhM#u)QJ-yST3)-jNK8X#C^lA#P3YMY?qo2 zi|&r?*l-$RWOJW}zOimP9o=*N*7r&8bUCvXTk|X{v%qmEJckI;p!Ol5VD@SIcG`D~y_ zLZZ5J)5(|P6tL~b8BgW&=_3>6~`<9K|29eug@kUd$ zJPQAT#^2WiPl2fs=IVmHU-3#pe=P{#uAuojK3A67WldWWoR+v~e&iszqBfP2>AbzE zeGEB{?OWj2HBQqf1jaqJmt5l$-2-?XI6azkyjn1dmW3n28Gq4&K)!TT%X&N&`HE6& zwr^Q6k7ybaU9Y(QifmaQA-MMrRG#R(qHFTCq8=v}GI~uG^}OKi>jyNGrk=cc4j_#S@6<_Ynj{xC+hZXbCoxt5=jjIkZOhzoW2cs$( z2ixM@3Q4Z>hQyoaPww5NA!ECvdp2pLI$zAA?g280A6&$mQVqp;)+FU&8`E^&dNYa} z8^{F~Qcc#poo>Cb!LKh37cXkVqy@m}!Dqc1PuTj2P&Z?F*mh+iZ90&As0yLD)?9-I1ki2AfTg+Y09N^}7pe@)9& zb3($b+`@X=(CubA)0AEk)ET^MESC8~reBT+-@IC0X%MAn=%Z$DDx?+#z`4R-jfXUo zBn!ZH22x2pyGYCC@3r5aTI~3spKJjd++uF|8b$%icaeDKtNEMRO6tAq_Rm+fiDaSm zc8xjHgkvC<-(1783MpfGgmKaELSKV-g;>nsM^+G}yY>KFM7I-*#e-<@_Vk)HmQbPP z0`%MtYN#rkOvBccqJzhDi*62aE+uWY8eVBIhVX%06)wD5?nPnNCV$nZkOmRxjQ2AN z1{&ALhJ>NlLeFg?j(vJ=#H6~n__*{EF(=9vmL0OfUL^=XxUmO9H+O)3-@ymMYtJs* z1+d0sL!3iT9dV-3gC~U0f_o}afm76uSUGT3LgUi;_{)o!IG4J$`F6zE{T6(r)m|hC zb8p_ZY|jj@Q<8(gy3*>aXrk8Q25AnW(3R*Btw&v9mB-t~Vhi7vAom#FIb#ELf+ub? 
zCvb)UiXisf$FhQ#Wr~wH1{DG+sI>K4p5MU&qDZK55IZ9Z^ggzPii2r!Q&&|WIw7L3 zYoAYhfsS@FdG0z}S@-P3hN$v0%GWDuE!NYli3&pS4Lb0vtD{N=Oy?>Y0{d-8M&j^Z zVln%TJTzqu1F;P5kI3DE`002HwFN)SP$@z#AZvh_g5k6rx3@!@xJtqcMr>P9zeU~(Ayue9Z z=#m0+RKVa@A1D#>K|o974ujBTSq|CL%1QJ4FaD|hijm^>8#W*?)a&nOh}guixI(`8 z<5ETb9}>q39-{tXg$tmImIS9tDAAaC_uTyz^$wn$5da}!6aTt8KPCyP3hT}blO|3I z65>w;K`6g`ec|o+O_d*SEL*gO+f10Q=^O~E@IV4;n50wc(jl_s=%*y9lk_hWt}oML zd-B%aL4!Hc^bK)VdfJ+}d>i}CugFPqF3Iv$e~MIu4jypAuaVH29XivF{iqahEH-OQ z6zCM$rQ_fZ<^5WsN^hb1f@!(CvnJ^aH26?cg z%imeWhuUJTtlz#M5^Qsqh@um4t#`^~{DB-uo=km${jsZNp|BxYwrrtS6@Gv4VSP zy1m1*t2VPx;j(!jt>Wlsp4Z+PH`ku=GEio;MoU-ud(mYCmzliI17KVsn1P-FM9H3g zFW*>&Li&$JxTo!MizMlA=O0?=h%qiZh1vV)hA9$fG~#;{4t6sfGzDSevreX&3UL2%XHI&%^@vOUig?=+pJf@OMftd~LonK9Lpa>%;g# zcA>#{mrKl6gvUbfAHedC4DBolr7Xu_T_1vTYI~$6MN}a*t|1_3WFe6xh8|V%Y}@d2$evEbKV8z39p!ip!QxLD*_+;!(6MOqt<64)VYJQT)LP ziI%C@J1_pAebV`P#>STYgb>3kcXJP_C}3A&us8B1$7ngz)hfHJ1j)yU4(X7SxMAEicirEMEOi^#q@VZD z7urf_@b`wtXB*H30Pg43{vmZw9^%(eTsJ8=w>YBok6xAxHo!C@C+_rPmqCK||HQKjD=hC~7M%sB8)M zn|L~1R~y5fbRLCE61UbcGw0j?k_M6(#Yd!i)R*KI*FH5EZZk`=Fzt>}iN7eob^tmj z$m6vA$>#8)NDU0{+EK@MgzW$4MB)D$JH+1s^KN;TYFZG(_ahR3nzEKsxdJlee*iSY Bc!~f3 diff --git a/doc/images/columbia-small.png b/doc/images/columbia-small.png new file mode 100644 index 0000000000000000000000000000000000000000..5674017fe0d8bb11b88fd3174dde50dfe49bc1cd GIT binary patch literal 1170 zcmV;D1a13?P)EDLvs2LEoWvS&7rvi$}W{n{cGf`bzs9KPN3^k z*VMQ6wMl+<{q#!;$%k2u4ZW-C9yQg%rOyrX*7wd-2$s~{tHJY3>he$TJ>=3SZ5_j# zBl8Es(XHXR{ekJd?(uCn5Z{{@9Wq#M{q(tu*B(j_2s41yzi~y8(EWWOiR^7S$vSNysS8ZTsAGgw~CLCfqT;qk-;jP!l zG)4@~?z1KC1^fNn6-DPt2#i&AttxJvzaH3M&s?V1q6&sM;!u3a2t-ly93wcRdXkwL zT{xUwHw`nPY%ITVmfmMz24wW22LZ3-E<^u@d%^Vm>MMmy%so^yemcjti3w4~Cq?Iq z!sal}I%FiS(S`)woJs23UgvL7sq=!bj3 z@UpRQ?PI+hmI#YqFtSNLs75e+Q+vokSl*ybMB7{H>E0>b+~-Ko6}QZxmrNr@^4`(N zV1olpj=wU_?;78>i)yD&L|90%#Ur=Cs6FIXch3Z8_p^epgHZ)5JAZH)F?^*)u7lNf zFK5x@f+jfpa!jq7AU=mF zd0s9F36?KlxJCXr@>Ys}5I8|ndCILf)zab9;O}c^A&Dt(sHVT9kb>B2%Bk>gy=o#1 
k(EsoH|AYMxcJ!Y=0amNL)OYLE6951J07*qoM6N<$f*q+#P5=M^ literal 0 HcmV?d00001 diff --git a/doc/images/dataiku-small.png b/doc/images/dataiku-small.png new file mode 100644 index 0000000000000000000000000000000000000000..00b3ba33761986d327b40be27a109dd7f877fef0 GIT binary patch literal 6101 zcmV;`7b@t9P) zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3&cawM@0ME`jTAECID$Ke{`8~FHs0ISqmT4RsH zQcPE?7A`;{Gf`vKzy7(+KllmGB@<#Usit)C6Kbfg@ughn=N#=7^izMreTl!noHy>d z;yC3w`txU6_x+9Y@%2Ey@3;Nyd6UOIPPqrVFFXT!nc4AW-iy2k^6s3s_urAcZ$OL0X#`Dg&<*mQ;)18?-tMJt~InO6&oSbodcfRl$ef}9g^AWl? z0j$^i_+8_^!@t6D270!YzQyT2Z{z#;X_P*%$h!L$_qw3t)N5a#e?G-;?^*U_5DW@>&al$Fc-#cS3g7a zXkXa!nrDAS!>!>iX`AOg03l-T$YQJ~16Q&>+GMZnT(*`t7V^_rx%0twGGL>LZn8=j zyiX*jIGLMgvekQCbmuhmLL#D1K9d>>h-8HjL!#aa5Nam+7-Ebm=2&8lE%_8uOey6g zJk_$#A;+9@&L!8}iZ7wWl1eV6)Y7V}9vX|LrkZQ1wYJ8jEi}ez+}}9SUH3io*i+BF z^xE41J|m1c(#WHXI@)y8Pbk35H1jO8&bI6VN-M0m(#or>y4nV7+wZXBPCM_i>uw)e zyR-T=Yu_{X=d9T~Yfcj-_vKgCIGyt|gcUl8k~1>qf+gcwG60~Ro$m5hB~2Tm z4P|N+%=fT-goia;Bk$2CAD3=3B`!)9+77nKQeV%QPY?gR`ekIs^M2VXLxLj$8uzMs zr@Wm{M&|T$voSV2k_Vp!jMQ0eovF$(kJ*p4ojSjsV4U^8lx?`y1HQg1zioHT%2 zPHXHm$TFGs6U{rJ#Awx>qC6DHs^61)b@APu4bM&K@+}UQSDX|IgA_NGqo68a4A52< zVw1cV{i`*3No2H#^2XH>- z_Ic;IQm(ZCM-pOX)wzX94FX2zPX6CT3)>(nG-F;Ln)ZGf|?CGxEc(ukn?<8ia~X@5TgRVgR~ zmv^pysN8y8pN>dqCkVS%o zgjRTi@M`!@F(`oh@MukOO|nwzVjFAd3QdP65Cy=oFnT1^=j7*r%q)5jo2Cl|CFr=2 zxzT{k$!X zWp=H7e*=D8hgv8lHXvd6otI&%HG3!8D2Q$ z1{RPzz!GrR8{9^am0Gt6szeATJSG4rYhp&Q8ZllxDA>M>!#~C0PH})|uszMCQuat1 z9`3~MiJ|a71Z%4ih-PKIAf?axb)+~=*mo)ZT7GFgAYb!5iFb9_2fcq3ld0MLftipf0HwBJ?J#1T^vm*LO`-98n6#L-|T+Oyaayc$IIVE zK{G{9HY?iNyqB5?tD;BA;4Q05V`M8}nYrNeb}lueN*aZFu@?|ysXdrAsYjqO>*`Ch z6iUDa5^UfLMiDYRK%1j=us}~{WzCMWZBh~7c0lQe%dzYG*!6nbb$s@Rd%f+t-mHdK zi?+S|uJ|-Cjc?Y?;%>EQcanZBP$E+Zl&b8uCIJy;jcIJ*#EFiF0)n6BbyruSdXj3b z@bFV~Cf6XnqDmfZ-Vgl}@3wh7lZr$e`mNKU&&lZmpXQ?%)7jfWt5}67OjExI{hyA; zZt3bVY#EO&<))bu5Fi`ZKA9RPQ(ZGP=}v%QPGZDBgm)3eH*5^YK$ion%;{;Z@fK$m z%HY6_3O{0ypLP^i$=yAG$cI)g=2b+pPrBkH%aA9&fnK#A;#!fXD?nn~Gs$2uBBX%jqa|52 
zY4a`U(s=E5dhivF8%~xO{eJOugaOfkuoC&i_TqE3cNB|99_X6UFfec2SJ0RI0eY!^ z#ry)i+Of%t$PiPdpd1=HaEkFFd`OEmi&}27=J1IVl;A)0_b*C#6P;*Xl~%N4Ts;kg zyjZEwLIyQY6F7wNJ>LJZHh1^tpBBm;=BC69R^u*hooy>%#Q35J)*|89Kg*V@v=shs zDZ0{t8ajC=ESAp9*nh1el;|CAvPXKKZ&!?7%3TZqHq@@nu2JQ=zxRGiL{1Euwu)kNE@L%Y4AZNVw)cPzlEZPV|4c^o z0muZFTZ1h__0+YM5JeVu27?gJE}5-rSI7y9OeaAQq19u-uH>W~sYSD6#9ZxgTPMVa zi2s5t$l+Ja8G4wwq&a!hHv9%LeI=C;Uur=;0hZd_?G&Esm;!NCVhSR}b9BPaR_VXg z8=N-)UEe7lZ(wh4V19y~y52lm>z`YG`}Ik4TAy;QqAkPxgvCDtQKwz%s;1UP;*@LEz?bfU|D`A!gCYHJn?6^jr0%_vW0uXCwZ-$m>^+JFwivw!pUG0!S&^~pxT(?zRkRbGb zTeu(QrjZZ=kxkGI?i4UV*<~re)G@>hBWjAruKE!}1QG+@j6ATYS>4ZyewFZo!H*ZX z)C#o3dGqrEctWC?%)aUN;R<}p)qfPOA!)__?F+RxJ&%#8^Fyvr(R8*!WYz0ut%2e^Ln;O<7v7xV4AqA3a3gr-Rs>qVI=RPL)cHh7nbh z&)eLsq6M<;c>#Qo|Nfeh^}Z?7zdo98Cr(2s?cM1IY47fSu%CXgQ$avNa*Kx0pyK|K zHA;l&G!1WTt*kpW>262wSlD!8kvs94(kVekC(jRUnm4>zw`z@*|6M`o4*#s9^zven zmg9;;hzvTGR|A51BQG|E4Cn-tzn>T1z5IDGD(vU>S+TOf0ZRS$&a@lR?W5Jd-{kQ| zPg~qR+Twh-#a_DNZcgC7XFEX>4Tx z0C=2zkvmAkP!xv$rixN34i*$~$WWauh%X$q3Pq?8YK2xEOm6yuCJjl7i=*ILaPYBM zb#QUk)xlK|1Ro$Su1<dj$A?7vov}_x@ZxYR+OnKqQ`JhG`RT5KnK~2Iqa^2rJ4e z@j3CRNevP|a$WKGjdQ_efoDd{bZVYBLM#^ASZQNcG&SN$;;5?WlrLmFRyl8R*2-1Z z?32GRoYPm9xlT2N1QxLb2_h6!Q9>CuVzld|SV+-%+{ZuY`XzEHA!3aYRWE>w)S>K&2g~b+00%Q4gZ@Cih;-*e-4E;s zG82HpP2uq(%*4@2AcIjo10T0t6bb+91Z*mV=i3S8_W?g{H|<&AUSK2eOJERC z<@5OqM92c|4_K#>tph~Jw(Y=35vc@nRrv)F0~P>FhYufqHUbJZtP{(dz1npm6Ee3h z7mC zFG*JhwJp@L#sS%?y1u3+daBEO8aAx8c@41JnCh+V2$pMCWNLCpx4ehN@nGn+cHP+> zOfN0Cv9_cWbFtIgm-Buz9A_f1M3p_jbX9IwKqT)OhHij0jnl;3A3MX&dUXEJ0dg7}TRz>ktp?i;3;@KKXhUBH;1Do4jk&Ai!O-tg4HQJemn+Q= z0k?J$z(ZZU<39Jamj^Tfe|1lNSM_B8_NEbl?TXyf`f$;vbvfCV_Y*+f`+>Z4;4NJe z7ySTu-LlkiJ5IU3K1~4w+}O#Sqbl+vQC(XU3Ez^`f23|od2z}@xG8)CLRAU`RCmtX ztJeeTzt9?>5D|E!gYg!HyHNcjFrgg>yrd`+-q*pUN-AP?OyR9G?(^n1Hnxo3wymbQ z{difKU0_+(tH7Z|B5|w7DB=zJj8y*BTg0I+f;byGq|6|M&!0$O@t^S=r+>w=+~@hMX?L1GRY z6Wwo2^&VgiaEH(5AK8|r3G5Dc|YAFoj13QphA*QH_W03wSKxg!3= zsx^HQBRBl}0RV@{`4!F8? 
zpH^rV*rBjB9t`cxS-O~}haT-MBD34AuWgNqgdWSvlE;AMwjH=x)%gp^E?|j>wFqcY z6`{_JCFqo0GbaTJlt4ZTi_j1F%$TUd1-T#> bq%-*+_F&OfrrLua00000NkvXXu0mjfq79LQyFq zy-MgHy@Vog;(ovLp6~p~bv?75baC~(+w^AwvYLlyQdWq6C^GQ@-QN>C^Z6ToDt{f)DleS`8YB`koL zB*Qdi)umBim-jr1a$CCmbL^0Xi8=$A$wBvUg?sRkKBN`CO*{^*cYEpEno&@xirR3~ zadNqPFS}aP5K>4b7Kjcd+k)lw^s-z$Ixf)o^@yC=ZttG16VHp8kTD`4nL64qZHmf- zOkEkfS6Nrx7xVoQ`BBm_hKFsgjt<17ep9}NTW)!k?L%?4to z%R*NoEFDyJyx_$4qOiWhdK=DY`4pgxwc3ux6*N1ulqH^CCN|#jP0HbCnyov5-%42S z9LC(`p?TGs8)*veOP-I^QF<5!)N>03>)}{oYIVyvxzPQoWQH}jP_i~PIyy^jV`IeL z<=%Qx$aIJL%kXy-v4l5OJ(-lylh*GFskkk%g)_Ag$!n3QH=sVa>IZRJSwm zm>q05SK{kgfA2|;HI~ba3(Ny9;nX{pylRCRByz7zRjEP9}XHjYB| zgtyj^_^z@3%1LNz-W#!6Dpj|bE&SEcI-1&mJBcMMEKY8WZG{IGAYw^68v*8gz44be z9n`F;vvX(PsUJaoS1C#M8?av(o2iqo>Y(0-+uYPXe!1OOyM0~ape5jD=_qLJBb>7C z#)`b29r?$Ao;3beW11)F75;rqK3WxADYdDL9pq8n@vm>pmqlJHjG&7#gWRy_Pai5! zV(z24zDS*-Q55x8vb%O5+pLbE*aJ{2Wzl76i#$Vy_??%e;u+=z51ryjHvlggdok*0 zNtY4gINTobora9IJ=0{;>_DzC6D% z^UE)kw+v);dm_g(9-fbY;ruA9OTCo{qh4ggbztS1%xL>!#FHbLSL>0Aj!cealC`24 zGoo)oQ=GYd9?CXaGdgn!qq#S9`O6d?HM#p`~-{;NY{%5)Tb!Cyx29ERR;&&yX)le5gZ+-Uo7OqWBJ!qwb;3 z2+Z6zGMMDS8Nq$P-v#ux+-cp;=f27ET(F7m^p!NUOQ*ibu{yx*$xjcAQ>Kc8g;$~m zX9q;fEu36YBRz#$J@d;5+F+bAH0XIYU#cU+EH^AKOM3MYw065opR>b8{OhXABHocR zhqx`y_^WH{h3sTc%q?z}NDHrO@h2nvxV1F2bk)>D#}PCVW56RP$<{^FXoj7*h$dwR z`!E7L%M)u(Nbbt(vi_Jin-c+!QNefWzsaa+DT^^Y*bsmr<@#NOsA#|1i0G%w+yQ=C z3*NUGd%Zw6oM~>}c8{2?gPYEXB^Jf<2q9r!mifa4HtdI>R`*Ylc(QK_J+Y<_^n9Vj z?EmHqr7hEK!$;5dSj_XalN9zkUbgEK8{}|O`m|*xJDxUViFWcf9>)r*&!K29T*sCx z228CB0qSTtz2I&yQ-tK9tWS}-!&wi_6E)2W*!lCV$!{CTG`e)p=Aeas3%`nv^jMEy zyFNwmRTG!+YjSO)SS99wXcJ^J&#q=g#75 zhzaiTvMfWoe%LGQPioGivjaZTPeWV zta}GUwGU@C;hjt$_bn3R5agbTnA15X@gkMTHMcy+nykaCp}pq;@eadSz3gD}@$!V~ zb>~*nwH3gTUeGDAIAEPwV4t2&-VaW zm4d6HhjjylT;?n>Fm=G-NOfjj#Q2PWM zl6ZIzFkIH6?-Ya&{n&4jppDRZI(Nx1wxBV2dMUno@2v^w@a$lDU;Mjs(D~~v9+*XZ zi_@;wcx7-te$U)7&+=r14B~rmJ28JKJ~7KUf>$h(>SK!qJT8@zzR#SE8*~xPecp8w zd6JzwWU^hdQCAcCc0ItcIID6mJY$2y!pkF1ar?^KybfO-iFqkR^&uw)7QOUiF2d6o zVg`#>qfr>>sbdTn 
zj}guLkk1R{?enO`wJMUaYb(9(2F**d%H*}p%k|nf18|xUxn=dLZ?Odk8~AQ^kmJH$WSAs;E z@3qYv@rIC(t3Oj)c8=d#ExLeR#eP~;0e=vY(k|Y^)mrL~4nV8GL-{8@77??vsHh<4kwn{v5ccRc)G!06LHofFnol^41K!t0BNa)PiR>=TA^G1})v17UhOLX8Rf;%;TTcd~L=)mj1PT7H@j=fb} z>PK+h{O<3shUG_Y*E;X}G`{V3+VQ&@WD0ZYOd z6k#p+yZ{B65v8k=Rbc3xtKo#JYQ6v*bw8|!?d`EclHrWdI9pO=YMdp-SHM20F++0|T#6ZCTa4dYlx)tDC;j z3oaY3eYg?j5*68O5S2lr94Lu?J&IVSkM#DE&x|a}?NV!nJ(}D5 z#+gjv1JR7#o^VGnpulYJO;}|V-}4`8BcQZI9_TV(jQ9byo0mQJ(}wumNqoglsxuNx z@j7dy$Fh?3ahN!6Nn|AFIJvTG!gXKn#479l;I$@${KX#^_#F-&TxOXUl$G}vXT~NG zc_1}Uwnn#IkR+m^rMhPgRY}%2gU27FwxT$iCh4Vxhyrdw?bx6Q%k90Xg z#X}u^NZCM9sD8v#l!!L})uQngKd!8m1^~LeN|1*+G*XRMXP@aK+`?uTp@%x=ocfF( z`>ScF!N9RzjXB7D=jv;zf45b(&St%Bp#UbKaq|2-kbj_PLOX=Fs;+6EM;%Vmc%N2g zm;iNl^w2!Z8aljy5=h-#93*Pu(cndVWmHU3YYIu%Ft(+W`|uqstMPbD^SA?L>30Cc ztwiXZJU?J-Gz`8WL`;2SV}K1o24?UBZ03ac@`2i79aNsrgs9S(18AXk9U2JnN(bky z`%uQEkOu?m1C_D*=KZj~_TzQ)4as5L*=UbgMiY zUsWs$TDqe-0d=9_-5-v{X}|kye%Q8azEp0>xh8MM13$hez@XEe1gEOcsyB>ONahF$ z143b_qI4CjpFerUb7GaCF7hIrH{lawu@cldD8y78@K^fDpbI}GSt^je&+5nTc0i$c ziP$`y^eBzxL6`;g$f%=lX%7)gzqTndR)v6`wuiB?z9z!PQz4|r?+Bc`l@`3@eoVs9 zO|`q`Sw-C~o`cb(Q6z;|ubo}Do`W(IS+|D?TtPd9Ls29!Hc|Io;s#a^r?kR_rTc=v ziW(f}I*wDP75**!;q2VfiOXiBVSH-#>1&ib5mt|f%H>~7T>Zv>lhQs!obXnQtrtP^ z5$%@HK6vN-FY4;EXdx&?KVjbM|7!b07alvYE3Gqq<8Pbr)+WOaN8_ z%yN&j?47h$+o7kdF-s!WDgc}GjX{q&12N-Y?{KND#y3i;b#zNOIh3KyX@vSyfXxV@ zs&!p^{plxGz=s32VOIfvge)b3o#}h&lio_Oj%i6G4EdrFrEl;Oej8TC?LwIObv1*S zzcNJPv^v>8R#*oFcYwMslqklUyx4PoEey?Ped#h_ye;a*NQfyUTf@?uOZ?v{VomP1 zv59mt!tA*sq&VNC3IF5S3J!R#>wiX35h#(GN5QnBJ}j^bt*4MPp(@8l4BJw5Ujv$7 zZ6S~@7EhPHX$gm^+px%eIQR@nGBq|zfHshec0^_S&w4jUkj?T~=Mt8KTag{? 
z>!~XWz8<)ZJ=K{IvL^SJ{FCu3E5LjP4R)1byM$WhhFLoP(GDV31zehpTgprhHPuIA z0~8jcN`2}fUNoO<_yP5D>^*--$vt6JMZvOx9t$o4&wrXLg6XY@l94y5How2o9T#WG zOfZvx8u`nY^V+yAK0;yx`j0;3;xJq#tf$vN;~DYMos$>vJ?Au%6ZMbKSkB$30s{Pw zArVs6c9S>O<9^#U&5>Zs!D*QR?&&;>U~!(^+?LslqWtuC18Nl$j0*b5B;sren++D+MxfOb3iN z9;EZv8E2U+p;q5vt?izQrbR1t&v_7Jim5=kho>q1&VMzDoSr(om)M<85&Ic{yw4t4 z1)H>=JoqL;68(P{7EDcgwC~{$+m&gI{_Y;~g^c6ES8;ed9AK}~7T*11`LSQBQG)2( z4gWSrO$p>m$aV9~PGHt&rtdplhd)q5J9@w)`!kU&&G~ESytw9ptUot_Kiqe&HCj>v zo{&(PDd&86Ylf$OB4+P=9rr4%J1J}BpT8r89d8eKzT4T^+0T7sj541};{qGz1uG?_ zS>~^^uZeL_Pq(r+4OaQ!?S9QQH+OcY%&d-mV2f<@N)8+D0qnfcgG7YpWIb8rj;y}1 zPtAVIyBpl~`b%lcSB;pUC05E}UTI7@_SyN{y}dj3&6l?GVyCx6=V6uRt*w76BepN7 zM8RzM)&dp5 z^|W#}W&nd87&P%m&*-(*UMhOsFWcyDMax!AWyVj94Huu9dI@ueWg(25k<;Qq_W_Q2 z)!Svcgo}&_Ei_Zx6bryp8wd(4n+>N=BUcKKia)mH$B;0ci+eYv51pQ# z>M7P;-eGwq)WmAmNAOrkIi|O`Ds|A!Z?rk;oo3ZpWg*YdUj+!8XrM76*x>EL&U$H( z#emLXbeBFx;*Is{8m?`AG9mKKwPv00;@b0JMOJ|U>LMHDHEW24^ECOzGqs&W`cAj# z_G$)!TcuM7=EX($?>Z7&)8)&^6-{>V;$Ao~Lk|}0@@2se)B76|dKZvT!2A6)%tTJ` zXl=CWC$}}3l=c?d4oet0TWwZbVtur z0rTUg^T|EMv5YPW3wFkA_O-Lj9~5cF^3y~8Jd!R`&>Bilo{JIJ!(Y$%j^7^Kd_Q(m z+gNrL*HWGF3TP~}leq2=8BPiJ;`Kn^A1TXK9se=hsRCjQf57(L-u=|YlJI;2il;VK=V_!?3=Ib~@XRJF{ zwtg`pN^h1Rn%dIMoH~`sLa&z~+DqTq8eg7Fjb`GVJlTD0OSqjt9ITB}UOdx+htZ7Z z9Y^J5@hP--hbt!7XXX?y)IA>-b_ol+oHYS9T<3YeB|2Dh24L_CQPDH+BN5(}x9ZoN z6D2gg#VR`KZ6#mg(!NP&{b4gs)Hxf5F_yeCHRvFY6O*#q+1W6Ppl(THpl&^%6d~>6 z1@h@<|1P?=o)^zp_o2ih32Fis>}Ga$mDLY)v~mHuW3Tuf5_+t#YihV%ORe^u8^!c^ z-n2Jw7xLiFmCmdhUiv}B(!4N}8xq2n5k_()V7Xrpmiv+%!{cZ4;N@}R&bC6vB9)lI zQmNgjJA9j}CVmY;oZ^~=b_&y-6epHVIiG&wgUN-}?M9;D2$KYFczoh6$Ry0(D~0>i ze&&~injO0NPiFAZpIcL_3w>lNc`&^P67=dBOtVo=3Z#4+s;T#PGp;=? 
z$DT0&?HvQ$h^vc#NSs45ch{Sn6~Ko1#||36=AV7DJ41U~!R{@P^CH^(pbc+-Z|{Jg zePkoqsf$N}O3`1S?n5o$){y0?dY8B}ClLWd1f=d#doG;nu<{Uca}Y0K&nG<`6+7$C zy?E$*ykYV?dT6Z$A15G(Z5i3$)vW|~0xuOz)XK(_WEO%^gr1Og z0Pn6CBW#Jc%d*#;K1EvHygX5fnvYH8Hm>Lf=*bEG*h+&l++gi5c*s38YzC=Mi)0$# zrc}s6W9EiE4M-hVD=mB0=u?U*4`@7Jw!E)Pyxaj5x*O#~&ef>6#InP46JxFoUpVtZ z?H;YmVc1wW533~S?RcG0`6{r~B54r3CG0nwXBO{qWw1Z5C|{PrJzZ(QTitnC-w>oZ3YS?)Eevyj=Un%1aH-`P`n=S)@6!{&zH+9w8uF zTIL-GtiHrZ->rv67EB35qSi0OCIcL?*846@%3LutNTQW z?SdLHW{UX>x*mqwpzmi@jMRMuHR*4*Me--UOp}u~p1*P8KwwQp1(+vI1_ykq_O?ao z9&}kG#JSa*e{GAB(s?TRTR8vA7Y~wG867DCPzWTk)Fy9WKQcXa`|tdoFN9{MQSrgJ zD(mtmP)N2gS?AA$kDWVv@A)zf%X80P6^RJ{-1lusLNp9{8eBYEU8;LMKfi7JX@B_p z)K*!qz#5OdFsU~sZ+q)67h}^2lO=HHYRrPsM>$q|X&}?*Hh8%`V5Cf?A@{3N@7udw zFMuBmXPBsN_^D+;OKZ1rd)n}$3v#1&#y0MzCsCp?G%(N z7ie=J*3YTyMP&m01hWZfYMs5`w#ym^_?~zC`2B(ql8UUeIS%Z6R}}=PB+w#A@^2#= zc$W;krjqdGM)S5*&{Dk4dpzBB&id3@MJMZp^!R{*>LzfZ+GT7`i*Z zgni8V-zHDsVtqqY2)`A!X9+uGWS-|}7q2Fc!J*7Qydx?$N`pPkKbI5)Hv{ki9G zsEVL&v(nM@67zz`=HD)CG9|Qp6uhJeCnSa2f45fz_=Vmasc`p~*pc}YA)bC|=u=Z1 z)!Md~91790)NrWX6;AZkg4YC!0zjL*tNL3xQunTH`EoC-sNSw z#`neoSRuFdNC}SGtgflG_~6wA%F^-_r5IwVNpJK-hJ000CXNklA>kpTFU$f)FfhBJyZ#%Mf^)+YhbkpSQl zMv&?-f^%Cc@N!0x>M)A^2$sq0y@?52Kx%-|95FWVS31)rd!P+~GiwPPN#c1nV*{Qe zW&BAnXJ#8rj4>@KV|QkxgaT8sC@+v0$aIaplc_Cm9SXz*&L(9nNoDSgrHbIe%ziMD zLxF2a?9tqTD-#*zbS43?FOj?UunMJ&%xrTKwW#1o=56l43L=L$NC2Ef@Lpy+luA;@ z!9;zuX5jO+JPo`pv)xWv9j|DnrT7CmKC^v6@DeG95xmY>Y)Eyn25z0%IuOi6s;9x+ z!o0;6xJG8{PK>b~!$|csh_#C;aPiF6pP0amyhsmHJ@jSfVhNnBfv;p%<~n|05GjZA zizTok_xGqI>QKu4xyMVsE{;I>iSM|B{aJ!&fctV+(_#wDS%MdcdRU4CK(AsCT#Lkc zL>)FG0q{w&2Tns8;|dA{Ip322II*Y$Gq<#R8m2ClUaisV@3J%K3xT00UWq zDE>ew;~Ykj>hKD4lYL)w%8bZsz_C!3Ili1|fLYm*TX>Hj_?*W$oTVrbxxoK7Eg$&& zad)pwHPKck{HBz}2r5~I7-L3CnS~P8VhZMBEvi_9*;s+KSez@EhG12eAy}Fuj000vMdQ@0+Qek%> zaB^>EX>4U6ba`-PAZ2)IW&i+q+TEI4awIuyMgO@9FM$L|kX#O*5#E89?}NPNnR+W8zmc)t92L;W0|=C99(x<2P=pA&sv_;y0|D+>?yXOW*1sl3mp z&)=cT`(W5l+xh&vOS(T*=KIp`f9FaucHw*~L5?Pk&rf-;q^v@_4t^$nMlIsc)aP;I 
zr)d3a{lfRseNn)EzOG+fpCj}~Io=a}+m`)Sr_b}Ld|kh%N$LH@xpt8^r_$% z^}D;>d*0oRY=LswXH(y{_&nl3N}Ba~oAF)wS9n~X@6LCNBQCnO+3E~&@<&fX#c#@K zr;%^tyKF>9+(q%nO$^MvO0iN) zE4>U;%T~;+n)B4CT5~PcYOSsII-0iJO0!m5YrTywU7@iOy7k&y?|lp|I`ZILgX0HJ zOqn|KEYoJ4ZT2}9@L744Wvi~X`WjodZYUsZ+jY0y_c-YQrISxNcIs)TpK-z3)thg* zcI$1o-|@X`pRWGewck7U_gzb$uF2CVKd%1h8c)~y?G(xBq_o0@eMeeUq;`7rb@<7o7{wxwLKoX)u%x|h13*=RK|~khBW>|PX68R5dW7a_;)7)PSJ8-fC@$R0@^4H^J)e| z+O^%&sx+<_rv5gwPNVDA(vd#{60PiZQptJEF7@gt71JPj&IB-M%?v|NH64?GCm!oN zul{eg7yfZQ#syR~k!@VDvv)t2O$Y00XSF)_Vp}wCv8{Y+x=MT3&Iu4^J&X5nAnMVV z?NW43UF|I0190^bt~Yc$Ym(uJXMjh%p7=BP8xc4sMlRkgjRZ_KLh)kc|Hrv>)T$?0pod1x2eW#ySl!M>a_ zPoJfoT9?)__Fnf|ztf&M z2&d}%2p0SCJN*8{pMRr5d$IP?G}`TTLvSPSv+{<9g50xn(>tfo3}rd(FVFO{Q#tDZ z?>Tu>=B@ChVn^=n7#AFCI2gKdOQ{zy-L_N3ey|w>1YRkb42MD61LXIMVH_)QihN;2 z+{m`J$~cCrEj!M9U9|I{By zvfee;E^UuC!5|E2f{n3^j5L##_L=}k53`?Wkim%HwCDyL@$4`wZkLBE`U&%3gwM^K zvdDo2WHVa}SJyc>!n<-P4v37!8F)qxtVLT712%1Iky7T7p4eYwJ7 z)F;(qn4bgXjsUO#bu?_)g*YKzZKpx^bAT4sIP@NbggU|o*f-o;WBL^Nqq`azkb9Ul z81!OkE+TQ)fuh^02vTKI%7kGct}$ty_``-u0hP+aDTf%`6=yrz+YRF)#ceuUl*=r8 zOUi*05a&wgIj>1S%T>Ky8_w*&x^1SMlv7ds2KPzWo$_xzU;y4bOt)Jx3{-R<2O1ge zak*@t#4+{=1&-oVR(9^kGK6XdvB54t*+mryzzBpv6!%Pq_cp3PLl1HT{SNBX5LpbF z!61i_ov1JXY5WMN229%722rq95{37<$##*hYcx7l${J@C{zgF7l}8*QDw=u01nsaJ zE_~SiR29RwZg;k=Pz!L17?*yNcO_e$=77NLdT9<^crgqP4o#@@3|9P)>6EmfE`uBC zm!Y0b)*KJQ&xWk_ruBz0n{!Jg57U8xTc8$m#1=9Py73q04NW)0$4>og;~@$>4KV?k zdDm)9mAyi8XD#=pgHgZ2E7<;EX^{QE?qaxWzsQR z_nIX_Y_xoC($Qvj0{=jWNC##dAVo;wLagyNP{UknK_msI-4DuS6&mRYH5yJ*gL%Uq za~PX$*H?9PF_KT!T(Qubej~&S74=-fcKWPc01{?mCVJ+}B(8(41M^G9pl-xNKs8N_ zI2RCh@e|BiS+h2bN%9#5!RE_iBf8FZ8*dtHLD^?vlIAAF;^DD<{rKMt_b(Wkh1>wh z2_R88xSzR*uvTZA8_LKlXA49L)cjoUP0yFCyKfQ{0Gfau*Io4%Wu+%K)kfFBRTI>nxlN zw+J=DKs`nWF+-JsbKD~>Knb-+Ii5lT$2k?P5o3prKdPxj*k;=vbl-fw5 z8E3FC!;JE72puYr1w+9gbV5a)AvMh8Fq=2gNpu%LGUat)C{J>S$tYpeSe^?eWKONf z9;iOef@}}!%m9XTvUBNZTX5xMO&|;74p!={nVI6liIJ?i2I1JK3)P8A0H_eyP-t*% zAwrZ$7N{tu%hmO0hy;MrS3yvmGkqy*LnHXvrx3U^!E6Ap_^`)FAr0)Ox@FhR%o8jr 
z0pu{r45%qFi|%6p7(nb)64bF^EEt(p!2cQ57j7evV{uQ)5Cak<`eMa7>DlEl#M3=$ zg0mrmeX^4p{cF6&nwT+umW9_>4bJUY0Z3_L#%ctlgT&;5g5k`u;~A-LzwXshA_Uc~ zck!F6pWlzLhCeAQefbD4R_1d6wRjKtLV82NEcSuy-iqmwL9FiyQ9(j4DNEP&ac7gh zjRu?w4Bhc=j~n)k^0<-`ii8S>>r4euhKES9QP3kV-KB{0G+G?29QAl@A`kXjc zsDpTD2Q!KSW!#{lb(k&#uT&iyROt{^r<&czg15m_UH}{&F>{hLqIb@oR|7AXpI*jl z62WnW#4@Ox#o$#J77k;v;=_+3*u(9QY_qDYEPkCZKh|zFr)>5(pC(h3>A4c&jL6G+ zDlt(H9Kx|jBr&4EQgfLuAi=XIT#cJ{WgzQbHz$5VzeVcPQh$5Pm+xp#IR5b_Z0T~D zy{$d4$k-D?F@CD`XDp;SHpvF~08FXYh9(?d3)J#?mJtW+E*t|nv6(Y(GQ%( zFkh0FYQV<#B28&Mp-7P@P2B?O0VQdPNlD8~OzsJOCXc%-pO7&g?xo z&e+EVb$Q1q5Ud%Jst$Pv+I+kRW@$g~`Sl$m{^9VUAyo{95i=P2{ZR%t8Wj6WwfK`* z1k{3M6A&=c1iL<3#$WK)pf?0yK&PLoztt zLuTNVm?U=LDrt+{#;;vS`ek;=g3#6}JhlqeBna#hPR|MsruhEbCaRJ3Nwre39Rms} z1T$KN&*3D3X5rcT73@lF%KV`MbqKn~37>+IF8u5Qb3&H0_j16hcm|qMH8+EHi;7CIJU|ickQPh(WYps5# zq~R}n4%lf7G{M9_lS}Xbkdne!6X5)=4Wm#oZrN|-`OOD_tbSVsOCG^=ihjG}))E>s zDb4Q`svDVT0`2 z!xl!AL3M*oQaeu=a?P$wkO}t~X%{3gcbqEYJVRc72ZXHbpIoedRQtgNLu-)5)I%gi ziOyvAs+k`9-QW(06D4Uv62Y{ly@LQCzkx+}*E>bP_?w-`sSZLh>{j1b$BB>80jC(z z?_rHHg(s|i8P%KxDthALg`Vi&3s?A~azW0hDjg>PXB1TkLTimK6c9r}Vy3Bp*9kFe z4*N#4=9ZLMB<>XY2NW*KgdmZZ0A2UAas3`<8i*KX>)>6C40Hw)$mwql1^pw+Q!cCy zFe(qf^ycuGFU@sWj)Var19GBuA%l^tx!Cg1RDE6+E9%%QBHH#F)(cCJhTl#`F;E*S zG01PtXpeLv4JRSzc9+WI)@#0Wl5T(!fKW$qNJ;(~yj{+ZTzKoFX~&Tu)z zs7_`9%>B?_H*YJDV{ymU>BLc%dN80~=}LE>chVmJbe%n4fCi@!bI4r=B@sy=l$-(S zCg^gN1+yW#KEmN;v;woMEvETLqt5g9yVyn5!W zNBn;gIG(3?%|A^^DUOg3XN<}w^pUS|1>}B8kTDpwQVT0Zu%Ap9s>$d=V`m&xTlT&pn{z$^LX6@=l6BI_cbm zAW#De;5AR?4}EV(sLFA(6S?;C`;ZJFEaabxEwaQ+NIqMIYVi}W9(5ihd8PnARW|3| zWqWp*vS3MYcrCldUd8+_>hN2dYQPA_>R0OsCo{kh*ktI~05*<KN7cO9j#`%v&2pW8m5eS`rVcauvY7FXR3(x&OwrFE|w62sf z8w2h(tFclnDRK;h&(IZ*;dd7MI1IY%ZG20dHG7Nv&TGV;fH3u>NSgPLmDwGddI-vf zFe0KCTZ0RGKJJUiT~lpX8+*xwf`R{a*gcJ=YIpx%}~<$#t*6h00D(*LqkwWLqi~Na&Km7Y-Iod zc$|HaJxIeq9K~N#Yeg#x77=mCP@OD@ia2T&iclfc3avVrT>2q2X-HCB90k{cgCC1k z2N!2u9b5%L@B_rf)k)DsO8j3^Xc6ndaX;SOd)&PPglds#R@)e$>9(1S$Ha7IMGU^8 
z3jy?^9sM%1j9E!apyl|whmWs!QJ&TM+@GUc%~}lbiNrI^Fm2*>;;BvB;Ji;9W(8R# zJ|`YE>4LzBx-kZS{s9P=nagY5dj z|KNAGR&HY4OA5t+z>DL230Yqi@OpeYZg8s@Ge6AEysMin>bN00)P_NS?CSTD-fX zxwn7M)cgAZ{g-lhs3Qt000006VoOIv0RI600RN!9r;`8x010qNS#tmYE+YT{E+YYW zr9XB6000McNliru;|m242p%%}7PJ5W1jk84K~!ko?U-9^R8xH8WSZ{ z6C{eXl-_2b)9G|BJu^M$Y#(N9(@+{p+d;yt$F+CX-v8R`@~62P3Lh`RA4VHL$=!o*Ft9-k#llz&0lbL=Vf_`J4PZ9Ak^4M?5DFhD=RY!bW%e1 zTHR4>GiTzOzkq?>UevZW^!_->jt*qEn<-0{V6xioexO=3ivQpz7?D1NVK9C38z>&n z9c?c+1)A*aBpK|!u^;8-VZ`Hz0S%MIf;}CA1HmB4t}Xzq#l`562wGns0QS;S)Rq>6 zVPLY`Q9Su1+S^CGknD0}n=ym_)>aa2SJC7BSPFgQ%$-M~wG}-U0|^|nW+R!+=&=}T zOA7#&!a|g(Q^(dvC7YGb&95U@%tk3IAJp6Ke*j4sg!X>Gw3^x>LvA-M)lVXqRZw>5 zGqhNgj*aVaZ`gu!&KzR@Tq5-0dlVf!M5v*G{tFiXQ0tFl^7^nWs=^3%} z2WrWE!~=sF#nUQOGan{tfP$@MS4AbIu-6{pUOoiLblT>#{+Ttz|k z>Os98j}ty~I#tOW&LPK`%}ByP&i8;MA(@r5!))ojj?~2*4)Pvaig#5t0DspX$Fiu3 z89U!b7zUB1^Yk64#j|TSo<|=`_4<{oK#0K>WwH&=Z&O+?Dayn#BiZ2qq>;(TE!t4> z^GA*%lL;hKI)q(AdRdpsuy+26I$s zTKYw$Q6F}^;+nqz+mSj#C%-3LTSI%pk4&ri7|Ts6>Vz{R$!s1yCx!UW2jhR){Cs@R zuc74guQ3()xVCT4*o(naTtU|wh36zwh^@KWX zkD;oV3a3vzP-o>VOkN*7N4`RfL;wiZp8!e1wS4)w2ZH|Q<{{qv+pnng$8m4ig0;Yh zb?L)Izd20*)vEwR{`j3_<7slAeeU*hRP9JHm&C?MA=8Cy|tihTtFDOPp)5_&X>oRfpCsEFp3Par7@Lf4TT zPVzQy#ryO#02J-mN#M|@w5?qOng+6rd(&15((98pcqe;$gzEPTV_;yIUiJIMU&|JY zK;6+X+cXjhq5AzPyP7uQ9xWahs^2ekRlOyfHnVh6TJ5%IVWec zzn$6Hk!NN$k&5z?s7ORe5D*Zk(o$l|5D<_D;144L9C!q{R6zy-fvw}Es_CL^ zWN&6^V@l!T>0nA>>S1XH0pYQ_k)h>w-ug=%#1=&a^@@xFf1%A8OyqYVI%r)b$2a#; ze_w>aCWX;DVcC5&2?AcZg8DRWc-nL;7xXKRzC5k(ZGXSJT6!b$@$(ZhDDl}0HHdF( z6||?WiCnt(Q=EDuf*Tr(M`rYon=**!)49rA7KHW{eC9Ho&8D`0G*8M6RQ8DvpW3P( zdYUK|Vmi8W);Q-?J$#$F@;9d@Xk*onV$Q)y^7k8hl8@9X%7N>G?(+StBzexW`$DSn zG4%XrioX-Z5Y&J2^!#G)_+6;?2JPNf>!F3)H!l3v^*x6k$X|%jWzPKZwGdYj^PS4i zy~TT~?QdcQweCYhslBYvrXv;Q#L1t^TibV>&9%FY7dyXa@kK7#>`8@>V@j+cVy9)J z5aP+Zzr8=AcycV-2U~x5#Wk7K4L#7%p7E_)#J!)-5BF0!NA}&k>GnErRlhkl3}(A; zj}s8Ykbqt<{DiLEyZi}{^?!LHAbl!Bw+||Wn&e340Uk19VO`P{A4J+aTt$O`vXS3S zYWa=q7#*Hp(AZ?3m2^%WVAgUW5ip;a<-$%UdzKPpN_%u-@Sq%R_RU$H 
zqxj{&Vnls$WFJdYDW@qB-^jFPbX8(%kOtDvtqByO@ zbH}t#@{8wS{3%TrswK@-$;nM|aM)0gGKIMF=u-0VgVjFn_+z1cpyumQ@?VhHVV3ff z%Hao_y@82sJiOi9PYPE`HdspRiy>)ht`A{4|Eq-;f9C*y^E|dEhRT{p;z|Xqgm*V* z)!@`pkbG z9nfOMn}$MB*bzxp^4AaS(GM{Q-F^<;D%w-^@_^eyRd1kXLKPgr)7c^J%m8DfxT4{Ti z5=;RYrewO=fqY)#PxE}}J?wQV(KttsM_IE+GLh`+(agbnIYa@2EbU1)N66$j6cm{A zs7Y_`*_AzBlEjqbRubiXVp;bU6B>bK1BuetYU2mzi5DcpG6gL(LrzDD$O z$EyG3C_TnuZ25aP9vu`DI%@3X=5tD0hBxva`8hn{OisJ;U__gqx?zpob+Vv*QnM2i zp0DulA8JDABKl`IUzkPPFJ@kWf2<0oU_XOY=8&GcQ{=)apfEN>_@O_jmy;XBd2=uE z;5BD|K|?A1N#si;OJpz|_xS{rF-|2SV`a?$)-jkIp`?SoOd;R|hh56ZY3`LQ+*_?H%*UmOeeIX}wGiv+Ap*`RINzZsb zDyHnAS=+rVRs=3DDRx~I#+UjtmnIiP#cYp?HivSZu!lt)87t8>3Uv&tDLvBVV~8+S z%hP?MGdVI9s+)#!sh9i2lxLtf*=|=>zC2Z7ws?^x;^_~}-$8-FVR=pg4L>~y5QG7-giP9zTxl6YPz{Kp9eE~3Fk5ESES3&L%~0zkF1KPj6x*%6p%80CWND1q?OcD=_HYLTL62}}5wV;8+aN+Q)-yq-L|gUMidrVMJ7`Z# zcr}J>WNoS=P3=B59bMgriy3jZ^Xn9~67z~Q5!+oX*_X)OLSzKF&(lV_)W74BkT|;N zswx_hH_V(K4^udVf1=A#H%?MWd`p$}IC4&)5RZtFhvT^vr=)5>bUx|F!8=ZLwfo?RsVm0?^r%tY@(QK-Cfzt=&k zv6)f~z81%fKD)|u!jQJu2%gHR#YT?W23;q@uV^J#EJ;|?jp{PW|+#`D*l#a&&E zVI>tDE2OfBEuD$N5XSQ~0%e9-ZzJs9Z*Tpu{Hs3(a60f4i`oMD`J|HrKGQiN_vwhv zEhYS8b2PIKd%PTZZeE`~+C-(dN95pm)IRctX0qf+<3-~v!i($smRfjgDHsDX%|0V$ zF33zS@oNm4_2P5?s#Z|k7iXNE<52l^UZ~iTHbiwS&5g4P!{Mw$fO{z!!qiP;;JzLdL4P;) z4rF29;WN7-6P4V^7O4V0H7EcTj)N{TpofuF=K4%RSUxI$M7QkyNk9OxbXruYKhTv^ zK26;z^=Hb1QVf2_^p`uh%Of~dzQVb%amh)=+7=AlFoNr` zSW<18FIaimqoKMGQM33&{xtJ!b)w>AVRDxB^=N90hBfG;!Vtnd6u#LI z4X;%c^#dXfs;Fu#6{7Y!2-84e;PtaD%1;KADvrHGXe)|EW_+E0?&beA&J+aI4qninYct8Iq?! zK@B}4mG}dXr6K`Cjzqn=7V?Q2F?x}qC`aDEU~c{gw9R3QB#U4=%p^HPfT-CL0I8Gt z^_1vo2>-@=wxZR{fFm42H^6)qON+LnDz;u&O9?5p(i0wSGRI*)@a z7s6`Eios-d1;C`?P-s?f3FP=^j>d-d6+T^zD2L9DC{zzl)?qIRl<4fi}flGejwKy;Mi*=$6?O=&X5! 
z7%+whYW(7X84j>-FzSy=B*9_2mk;sDZSbuyy zhSmeo&ZyZd+IF%K?pzvk@Q;*zXy?~mCSctY@k*~lqHJI?)0H^7pBX33D^cF}ZS6|O z`Y&fRdGCkE>kcW*pXjh4E*DrDpRiym#C!=~6Yh<~cdl102BH7T!kfo`u?S6Ja^t10 zL*w6F=AaS}l9n7E;-l(7A+Edc=#d+;D3vIu?O=el&4~jo(@zN$H#^X)w-XJ(wse?7 z1s_~QFUTbWarW0juFL1`>>mPggWOMkw(}Y@7+^ zz{q~5E}B9#@I&jD4Q%;sb)}3*tabOpK(YnUF)#0zK|i~I*ICqcStk_$&Fy8dZ4IEFa-8? zi|VSH^pf1A^z=ERdWN90P}zJ9)=p1pT2M?T)ixY(zhlwA4|BWNy^C^|RD1hEojB89 zjnXQoe2%wRgjz!Fsp`w(v4@!;o9ae=NzmwbUpoU9g1m^e6UwvBr}T@?Dz{Hoey7*!77-CZ ztJ+!y8&DJ^;+88p?ec3uc-#Uq%b@0-`rRLAw0`k=25>Rj^zev0g^W!3s}*#QjA*itS)~ZMqn$e%I5iEq{7Ad&l4IN?XtZuA2bn zl(V0!4&deBSUM-&IB$cafi?9Oex%}wG;91d4sdb;@lNDFo9tEos_z0i5#OqwfmVfr zB7V8B^2ao#Ag71iew^w!(^&kKz-SJq9UnfCGjSUL4Z+|dGK$*A4tUQynneEUABb`e zGYVj)hua_DxCVCH&{n&iSYB`EB6i(q#Vg7~mq}uh@)KJ`CVQ$eH*`B3uJl}&b-?x? zbAc{-Ixnw(>+}r0p4$rl!6A|`vV64AaCOh;%=;^RwM0fq=?)9+GIFz35^CrUNog`T!#shgh}F+8oaJXz+)UUsJYn8$%8${Ovt4bjBRxvR z&y-5<0H5*0zs$54!ag{!biH7RpT0;Xvl5_~q@oGE6Zj*YEIQsgI6LT3PSaXsL0O=x zsQ&m%=$6~H^co;7Fk~lkHBnoQ50vd$BJ1=y6hkB-T1}M2-^*R>Wo|}=G#i_YmdvkJ zNFrt1>>5pgVI-V@VFiHblAKqYE`@=k1+#Ws$ENiP_XAp<&QWXjF`P`x(jDui(@c!n z@+i83xGG&zAV%%MpEoU4uY~$DYLWQTV>V<;mSXYavY7!ZtJu%l71_3^fJ}9++shzN zV%V-^TM7KSX#pnN;hITwm+H3>8tr)Y`lRm%NBib;{(Uooy!@6}Kj+B?Q6cgiv&w%Z zO_+Kz>RiC`6JwXR4rOqVgx3GS-`h6H$Po~BkoA@g5qOM4r~TTKZ(+d!Xtt7-if7Q2 zi*OJ*@q+UX8L*a`1K?r5G|7o1dU%LU%kjwJ+_`uzl;SWWNvt(X>%>G2$30tXNAHcE2&RIR*eH} zXaU(bTfh&beC7*bYVF1)YQ8Xrar)q(U^*F#l{~A_(pt>IcX*ieoD!J1Ex(V&K7CvF zY>WpXjK{K#5{S2l(9}h^K`ahRFKJVOWT-E&+*MY&=1HC6+|qmiN39v?4~*>w%#e!a zl=>VoGoS&)g`P@mgaT1l*8!@@Snb>5pJOR}r5z4_-EVq;(0*0tPjIE65N+HprG4~c zv5VPlNk#k!>d}y&I=&}N9SXB#W+cNIfr@skLXv+eH*!nlR#${0YWT@H!+ZEgQYBYk zY792WeXlhPKyC0)mP#s@Ajp@s3|7`0TxoF4E!0v}R8d+~^#7IFz{RyJ-$VhaJ^;yp zfr_>G7bHAP$6>{6#)v3xE8HkW%6X(J-Zb4;N)AgYWnB{3rJ!K)Xf=3zb3IaBSmA5v z{(%MQ8F86+17!=$vmQsG%?r(=fC0cHSS6|3$N?Bg*p@3DA$qw?BlGqyRN%s-`BR=6-dlrwP*$D$T&}e`r~yq z`2F^>+_kpTjo5!>QKMl;zkE{9NW+zj^+@6_u;M*Hwv1Ps)X8f4k-l&>Vn8r@Iwq@I z!V|)cSBrp1g_B$tNM$2txgcGaPz2OO2IZZl4*yQDhK?b*qZbZ7f9(NnySJ#@{aBYp 
z8VKl)-t$3$?$OIrcg(sn9xBU*_%Pp=v456TGkpDdFC!XF0x7W8J>V*;@)$5iYPda; z@O|aiJ4dK4nuk!K{V0TQ3*e}HNynGG1RW6dS7sdoKD=8hr+!Jhy{+(1i~rR8+Z!Y> z)yboj8tF5*{0a9%R#FTC^zW6|QSuW!g6JTnG5004Lk_%BK=()eh(QI z1BnPezm^38f&xNXOjy-p^t_Xumyo-CYih;2f z9)*gFE38zfi=lL`PyFciP_6+3DUPWGfDyL}NYLjR`J`h$`6Su?86PWYCCTIoD1Z)1 z$X;1tH$Ti~cRnPMMxX?UiHTjv4^dH2QT?FD#>B)FO^=9($d>~TqnJ`rQH6a4kFh03 zMnn*dg2(z5KmFfL{Ey6kEB-H;|49Ew=0DQ^@$(<)|AUGDe~eT9Hpj$p#fX9$I0@Oc z3sJ<8N`lGmLEz~@8?iEe!2mN!ud5nQe0i}p(`t&mZoSm~HYX7gK%P(}Bt!#40dsG8 z!ir9eCY#$?`*4D6?8cS17t35=L5)ZW6CFq|5)$BkrO(}^8}|Aw72w*2jAhR`+5e69 z-QW|)dc3I0$o9e9-)|AiVL+Hd!Dq+sUPsT&PR=ca=!X$@&u5w zU4E=MF2Oxkp_7IIia0R*WY14m;crU@+gf!fB8VKYuGT?341_$E8v*PF-UoLM@wa1#!zXDVFO%x{H0Ga}C|K(FO^9eG~Z zKMiIrpZgGOzGu7{_XnTrge zH!~w$_Ut=64I*bQD4Z`8Y)m-8@EQfs>2vn8oB0)0_g+N?zTcy-B$v9HU}W@2-O>Id zRQh?=Y&=AembSpadVEx3_dq2B*mnIkQ)|lpg#K@ao2EounD6nl+skGD+#rPDRX&Jb z+M((#hwq;Q!J&roEBc;;m3TB9Zo_15(*5=H-VAj4@je%!kR%M&paM5cp`9OTx3_Ed z&@_mSba56R8To&SWITD?ag^YT%rO<^u%XJQ;C~mxMa>}&96?j$%@OCd42jCi2Toep z05N_+k|#})mROpVDAv1t3Apy8B>WC<-aLb#;L5rCmHiJW5`6X-a%wwFpM$LT;5e$^ z$TsaegOm-B2__O;{uVP}t3mEO3DZG_^@tjg`Y_r^lA>D2=OoA#n`zhzfT+@(?pA?x zM6>L9*iaK4O_j;1kiKvA%>N|38d;F%*9SM>e)oyrpa$yW>)#ZUSib9quM{JA2oRm$ ziFei-F(=L6r;qR{c7hP4NO+LyH~b6>i2=LIkM^|C0rEBBS?RF=*Mspow7fs>?PbUUSSL)e9Qx?J3`?r?AxXM7&6@i!icJ@_i;>)||C_QCK@SsWRB4 z1$rW+{`9Fq8N(HDek?x2gaKH9zcRi`j=ZjZt7b?L1yq$uyi}&s>o;PQ58ZiVj81D+ zlJ1%#LsCrppj($gHy%btFquYCmJh|wipTdo+=TD^`OPTp0M@6d(XmA5q3r2_0oSxG zW*k^6SJOLyqP(z>kaOW+?SyD z>;gl4n`sq3t!m<9Qwkj)5L zNf=CW-7@nd9&T5Io6@O`+%>Er0S}Gr30#&)U^iD~?XxtWx7*>a__R6}uIE16v_x&% z8+!3*m5czR$n8n>?(If?do@_SF?}$Qx@JqDNM%J9FB3&p9(8M;YsbSui&*Bq9-Ec! 
z-o%>#32@_1a2)tGtX+%nH>4kL$)LkRFsnkn+2G3y*!68S6|QvYC=d<+{`v3FRO`7< z9-S}c15+mB+3-VY>>MpW_TWfE*TthLONzKalpT<`4Q2&Ka|#zuM#{j!q(l2$`1c4t zFV-)}|12{m1TgV-#YNsMLgD$!REy5Xc3kuIz+L#UFkToRf9w5F3QNXp6YeTU0Fx=~ zm3*j=SfZ1!|IQ4F+`y7y-GJFvP%54@s;FTKapft!%}soei}_&7e>;|Fp}uMC33?iz%i+Fg;gf|D{hitBMY(S z8#>@r;0e;=gRogyJEmaGF0iYfs?xY&c>2(9-T8n$+OaBI9>Ac&XiyV?jx))C6TE zbdUb;Hg@iMTI_5a{9AGz%`^=ZWYQfdZOmf@Y*?>49!C@0KajYymqFdS96q;7(YfsR zC-6hZbDCQJqQ=he*LH;SRlnnIv}Yq!=b=y1DT$sJ90c$vHlBPH6}6yOYo(h(tWq-f zYj=1qU`U6CLa6t_Z3+u-2w-r6T{{atwC>a>gy}m%2{v3NQyI`a z0(-0zUV#}UxneuX(xpHv-2N+exu#I_C+2hiW8w3z4{aEkNAkxvWxF=)6fyx#r3hZg zxMT6`9QPqN9Q{Xksc;0d?DD*|)Cx+=bRjM1skjUJ3^!C&cjF2{mZ2ePyZ6utzEg{| zySw~`)eJ9}Gnen+wDxmILAQZ+P?+tSV%NRUb)^>Su{K*w2}Z2}KPI7u{(@a$^gj&v zgE`qf@46nDPmCTK&^MK-=yT_SU{9>wQ;y_8WOTer>LmP4>Jg^$)Xc#Z7oWgW4dpOMt7djmrbmq~1EgOd9|GEBY%3%lOJYU=1Xv1G%>G0?V z5jNB(JjD+sa9di5jrL48;5!-k;7XfkPax9$JEjZKuLOpIhs$Ol{ih!9?=2~P3*5FA zfRW&*Jw=5e7d+%y=G>6=BZDx{$? zAp#uWP=5SFfbgof``D%aZ7+&3 z=&qV*-tUl-N;darxsbgn@Wb!X5QMywG`C;zgkK+Ze zJj}q;x9>WE@iSUe9d4fy5~P6==BU00nKw%pK_`#7=XNbtjpdF3(Y@JIh+R2uPN|*)Aw@VoG4#z zrrPY~tLk_LVq>#PGDbi-V3rsn8?E;WoS8f%OBdN^g1kh4SZLTI>JpcA(Wo87rrGlZ~oxo`OjLjuV1!x zJ5kgW1R$Fh3D-xV{4^B=pv!(1cPFO~_0obDNyIB33ti>3R(xALyS80hElz2Gp}1_n zks@)HnCnNUv~k4-qY#E-&!-FVy~V0)8zAl`>e)_BCB(ZR03yNkxKf><%`Nz(w9x2c z1$&W9AU`_ZyyqWr`g0uE-aW_=q+`@9*T{diBqmlDb;STqCS5Un;)5fOS;dU0lmBYd zld&WNAVkzC-XQ@{G#!wnucq;-wSaW{C*a&#_@;)z5cia}1JNUk#u6rZ(JrN~B z$XQy}=JMa>1@aRu`Nj{DuFc#ltE8dPXoUfvu?>euUYvjOlg%?5z;(?LcGlnp2nv!R zmc#^2n(=KbK$c*0lKbOe*VsveLvC+qLd;;Z##9&_$NMd^E;x7DDgWlce670TxjxH2OW%>N+Gnwc;p8ACJ2X0b-<$G)bud6|n_glY` zWA4-4sXx+s@`Vz=AX%mO#YKf5Qtc0wmG?S6a1_OKZBNsgs)=8mF3{%^hwX5M26#S8 z55ISv84dG$j7sf16x{YbDc$w=3h8K5ewl;7s%eE|!2SdOm@bv#DCwBVY1l79=J@$Z;EP(H4jmOP*xY1_O7)pPVSw z5APumg))Kb&5tIIWM!0fGkUaTORR5*y!mQm=}7!(VXK8?6tr6-@R- z*!h}yui~@-N8oxDywi~Y$s+(RpfwOW@#fe#04u{mrqfcQ0*VlVaifj;Xl96lv_P#a z6qq!YDCj9uVdpxC&VxpXE{K((UQs{+43tb@^$}}0Q#n1`)_>&r5u|8zwsm#A#Af$x 
z)W93on@(@`GH_kJ_?>y!^O<})U6lPKa_`hq>W`twa<2WD8=uoOciXVQ_x+8wx=bm3nH*{1FL5P^PVkR4I zZN*1GcwH&E4x(micnM$MxYwuTpYr6;Je*9Q{WF_1STrSb$>w>nKFJ61bjkL&2)vv- z=QRvi8U#*-&qw~`5P+8>5ri7-gbrH$m@m^1{+hA+(KZx^Y!xLr@ zby$enmhOgHvPqN7)q2rzYkzEe>qDmmRE5|lACj-&fAQt|Es@t!rr(^C-OQBQaW^^d z)OS|vs9y}sQlYwewCKC%Ug6`@iUJ|77`v9pZp6>ElHa*$2y11lEc7wZ@j*O7279Sv zwJmj_z%JL(#gdw4)4g2`d@UMn&pLb!otZoYFT2zYJd z>W~7lp{a#4O#eZxjMQ)U+)#&h8>>gVC(qarvA)taB{^s-cJYy)NwKdO*5XJIN-&gG zkQb8ES5#*>8>46%UOCM{pTQiATLb$XLNr$@x7Pftn8|nv~OgtauvKOp$8sQlVR7blv=j72F>2+{ZVNM zrp7W(bBO$^0)Bk6*$^H2cXvu77B)+PC@x;XkEY*y1agy5gnir-vTbR8^0Wed3928? zjE@!{(%8{oK(U$r5(8OnqO(}v9!}z)8zzGuA~snM_BlwuF`-aUg#;+8MGoqLV=66GO{x`6Y1VqygWa3J+gOIAJ?o z`=$C($!D6WsWAfFxAa>D-(ajv2Pn_1*SmwSNMOn7ep4ie-xItZumJBNvI$>}pinGNgsawoYxbp{c5IC< z;%rZs!Tm~wO{8J?22rOVpRkOfBXW$!u=i@#gG17JWhT%OvT&?|;qd}I;kQjF*&2Py z6*{K^=#jli(Ozk~!rdkfMbIJ^;IJ$;!&0msffrA9DrY$J31OE`7lpO4G*&m$@0B{i zRh{@Nxn9J$rtl{QiPp&su7m|Gq-^y{aLGai9NzWH+9e zh|^tnH=Z;oS;rho1QrL@F&GZ0@Kam-zDVFOlxzA{JV8bzzzrY z7_U#9GTFt#o zM)SvMC$tJE2c<@z=?y6rD>(cbh*sh!<5vM6EIJ@aR^XJqs#UgyZl~(pD`TO2aFhMk zFbD(gD;Px;{8ZLM9om`{YW?PfleD)-2{K*efUsEaBo!G8U?`93XP??V3!FE)R*t17 zLfe1#(0qG{U)j3)@ch}VbU}l-NbOYz1qOBzB6Mk(qXksS{y_0Lr!dH1G%Hu^e<+Ou z%kCRgm+b=H=*fvsvn+}!k`Z#GIDkza5hgGBe8~-}QByMjTmhg$(vn>{THxOirKgLc z02ekgN5vq;@0cw7^+?C!2}W^x)uEyTeq`%!aV*urjW$jo>&SdEJ{9^Rh#wlkKt~*w)1$zs#On-KF7%3@SI(&V z?@>#tGb3_7W}1p}pMBwTe-AP!PEJOR|9(ky*maPSFh-b+;^pvpw8znR!?e<h^hKy^$%sPP4th_PKCNC<-V`3UpIc*1uy9c0ztGxJ!OP(g80 zF5R;oe8}1K2qVTr`Ip%(d5a5SHrboxOZmrb_C|cpj7gvm_p~-V2&}Cex-cicH4y5Sg+fUhLr??H^JjJoDg6e zzqPtC$aTqZYGQ6jyNfKSqbU;HdUBxpj^q~3>_`?(+51} zH#?vGoVU{lhZ_A*UlqFi{m3EA((>In@I+S^pK>`&-Wjc`;lCenMVkV29EE;Of1{A+ zyY4kq9zI2+Hv+}+_kVNMlK#zVkp?7I+O{h|^b?DJhhHA243Be!<384QmQW?BA zKWBKc0slLmCE%&e9#YR_mDVa4`Jmw{)i6f+Bp za4&7Dq{}Sg@4r}YHEnPUxqiJjjgk55GX%=YiN0VoC-(^tPATN%XHtUTgUo){4VsPd zWRwNB#$;yn?NaJB4_NYi>ybS>0ei_xj0gp8KSn1+A1~MwpQe@!1w2(vUcIESvZk0N zUXqrUe|!pa3b$rEayO#HXkvGcCmT-pWpT>&#ax^zk)&bFo)QQ>jTic{16HAkV?{=R 
zc)+AyL&DRPh7sSJE(HdHb@ZRx8BYh>$^bY!=Y=sSk8-}QmEJJ|+-@Z5(t zZu+zl3~=Q{&i>c2WMVrwqCR^iP5>>!M+iOH8xbyPmz9gXjDbvy*ugqe%2}nQB9Hdf zzMassrXBHYc!m;$wB&8y*WvnQEF)ho`$-J5;YhN;&G>Q^m#4vu`mB!5=y;Reg|yd% z@K;UOD3lr4+?8}{Q1;+ZMLvEi^f^wrQW9GXcUZx;hlVm5fceC1dZDr6qJFr-r#wpc z7nvMK8sR-#^cf2>1wWdozPQCG%9BJ}U=Jm5`i7|6{#px}=E0{Sun>k(^-O z3#!e%S3`di%9LK^h&u2N3)_{Xek1N&t07<3#lj#GipFSRI^mQroD#g9+v zuDl*!J{L5gf|cudGdSrmtX^^LPblg5(wU}fnQ=6^6Nwyj{9?P0;a2A~`Q3$w#&>W% z;$fkVfOpf=?Ccyr$E)MeohEc^sWbiD(HgV?2kkECdTy+eVnptpb9nG8<1}H(7gKSy z2H+zIptBa5VoNSlIR`}JikvZh?IQ0zuxhc@_|mkMGqjZQ-= z^?e-8#83&%*aUp?AfCD5Sq+BFTVo3F!8O4oHvinAUoP8-UJ}x{Mv#a_5BK@=S{Zv} zL3fGkJe67?n_v!K4)2F}1C4N0i!~q5HAAX0K{buOA8abPdH9nZ2i z+9nJj+4FT+9jtWd?2S-Ch&o}3rt_jZFI8ux+G;k{`=U+j@o_$t&X<@0$VFQkO7E13 zi1OxnIh)X8nRGq0;m~G12AA@|EeSV<=)nU63u~8s*BKx4fPA?N>`~Z_0M!m&KL`kn zgMW7c0D!srO;n+7aIyWy)~!mf;^j}8(X`UER@Fb&{EGqA+Xo!l zmV=IiBhYtF{^6J-H`}8YOu7Npf8r9QKROb{FS$sMR>YmJ0x}n@Rs7WjH+z~74{Sxl zoJv%N=wMDYs;4=oN|~8Ri|@fLJ3E1blEW8Spf}RYvyRD;O|?Ob{9R~36DT09oD1cE zo+4eYHk)1kY_` z4Nb$Jl`i52XT#$~(LG9*{KGzUDtFu6o4m&f+Eu*AR->~8uvuis`#wtpx@#;~05gK~ z62~j`4sA-n@gEao#NH2~JS_2iSr8DNqe!ZqZ}LN#O44ddH&gf!9nFD)yW#QNHnpGL zZQBngte@XiIP7BOtp7Y_E+F7s%r#b;&Yw)vm4ManuI}*^RoYa|!D_A%T}&6?uE8QoFy&HU+s6tsirfI9sjjZiM&rRex4`j5`TN)@dxPQG+ST4JhIuTqmoVzbUgc~l7dLgrB0`hQ4pYr61Kf;qNb3Op`50ty7~Nj3DLdE$b@rUQ4wNZ=75L5l{p z_PgSxuh+f^@pQMuR4`cv%#tJHgc53n1F4!6T;l~X7O z7S^03z9_qc+2msNvZa1MC}nbeiL4N`i9hO<&)baQ+J!1nf6)3BO|=Sqs>vwBrMnh` zm!r$^0gyi017Qcg=|=myZ>?}>N)Xu3EP`s$LN(-tH_mFH>K{MQZt(TOT7@N>Jb8vJ zOZr^Sr}N@oZh74E-xoB3{0A%POMO5B$rGK(35`RtulJex?3u>(g6zqKim%CK)A%mv zrxaD*dz|MTCvdGTz~YOLV){}&Am^jpU!2Zd|JhC^5emms&RSINB)*i4sv5`e(<*R- z*1rQwd`W&FC$qTFZlZ+$n`68B^0~Kjr9rXdnp{{9hE~^$#|9dH(8hJ~E+S6Q z$%nt*d@I~{NiK9X&(}&~D?3swxI~K9^x}o0L4tSHys)fFe zax(Pxx5wUxz4%gp5Gjm|tcWtp!90h_#2<9t%v?N`D$TP}Kv4S*LeP)I=; zwh@mf2i6oBvd&iL_rfF2JHF5VdiQAN%Gq!1KmJFolHt%5%2VPulsUb_fs-cMm{Y&TCl zv9LI)$epKF9?doqef2DP$-g)@29a+=OwTv)wOY0yM@*vXk;>py()kYDqFK8nZAM2! 
z^37SE-EFLSG<^l5`RmdQ zf-E6n`dvwz0f1oS!O2yx<28J(JN*3p6yl&RTk!1ygI46G&^E;{FnKT^tyHNdz(7hy zco11#or2K*I*8&PF>*$O{7a)kDq(4XHZC#Q7vGzlN?z}}8jc!#KMj?< z$wU3@`O}pb0n$U=7jdebMTh?z45=8{l<&4hr|AYJZ4@W)ycyr6wmPY7nNYOCyociC zforzFSIVQZqr9m})D0P0#A#xR^zrXC#>Wqo9UV4=Tu^_|1)zMDoj5j|h4;|UI9N8H zCzdX_ynb2lc}=|ixoJ@76>JUvel{-ck80IHmejM1)9LgKSH^}ei1wl&h&v1TsCmv8 zO$6>YB>Yt28=y4ewPL*`%SI1gJfSBX%~seb|COQe$@QuOI^XemG_H5O3BofvH|!x< zCI3RuZ=Xg!BPz@L5P()bCo^=g(Exv0)6|83y4KFH`x6o3U(=}?eQU10{nWAVWud$l z{CC(qis4E5MDD%Q1doA*VP8Xi&G7if$IIAvbAtd=u_3=I`j2-+BNpT^C-}XUCT!GP z-%3wU&}FtmwdTypm~iefU^3gC`k|U<+O->!`<};h*y~)nNKX(kX@=mJs}c+pz5nIO zc2390*UvfHB)UR@Ty*78QQuNv1SljI-otmkpl-onCCwnLr8V##PHL#7!J&v@chE7nzvwUn6`i(VNmBHD*Xwi-$K=;3soAN*k3V3u%Ey z*}{KMr08RNRdw<7LWEq zZMF+i{n7eTF}Cx8_f>({Owe?8h$?KPhHh|UGz;QcXNBG$DAT9$cc2Rq(cxY=rT5`R z!k(!Akc8`k@z6kZz-UPYI_YYb$|ljGRd?~sC=mv0rCzMANKg&e!Z+| z$VjV3#M7r=L~ATF{;${{ex%CP!tFQu+mYxYmdYsOI5&Z$$XJ|O{3E}^5(=63cK!r;>}{x?G!Z{ z5iJjXhiYv^+VPG!*H743l36(O+t~2GrUqlLEwZ3n7S?Q8DUAB1^Xp{Imnt zIC#BnCv+Zun=nnqej(MDraGKVo zovaSC!}mMswv`^BtU7G!B&4f|032Un+%7`wL4Gh+<&@^qRkV@WNM$%tp}QE5YxMH;6(7U?C?H zKUoA|J6|K)@gv2P^gO_C!LclX384@L-RA(Z74fNQ6zo-t=dM|Ff0J{NVL*vmafuJHN%w zawbG|z{l%Xv!VkyB%BPs>w@N?Uz9iI?$108Tnx1Aiiwy4NRQLe-L<=yJ>}V1DWf+5 zyclQEHLHaoVHr`^icY=Yvz})NQe_pdy>pec>hkJJzs%gTD|<h zJRfS>(!#@4zNc@Vvz6au+X!`qK%D-a*VneyI;u9b1pB8m3V~+CL-x4327D9SR%!Y8 zz~ai+kaa*$Zf%e=khpzgkC?}o{>bY^}Oj0!79Op@c8ey}$ z?9l($&$j8RczW!ACFL*;>aeM|QscM$ytei6sqDme_nKSXK|vS530#8pulM?`Ej^Vd zK6+Tz+i>@T{Ts6R4R?nbkA%|lbfwSNBq#K^iOGRO=9<^Km5sm#c~pJCjYkt~-wR4u zNQ;1FR=+y@U9sj~SZ&y*TN6y{6@e`a9ShLn@YYG?=jE>K`&RUtkIh}ENKChFt2jHS zB4~wzN2dcr>Uld}7N+*UYe208+pjV@Z!&#d-gY>P1YRl9`^_|Oeb^QoJAsHM_Uf&E>6^kQ0X(v#WIe!@M_hjRY69hHcVQ6wEY!-&T8~ut8+AQvBU`cb3t@A%>m?{AcK*mK=DC)GWLTd!rM zx44)XSYGTzBZJqxIG*oasqBCDFS)Gvuc`ZiL^0dp8K5vq-F9Frnd6m zsuHd0B;F*SbNOS~ys+FN!o%gt!;Oo@ORKzUZN0?w8q%*$-_T>lyDgdR@U`voJTG6q zIp7M41CLKR_u1wiWlDBaTk!m>JY+e2-~@|B;;>cq;H44NU)qk%M4*Fcym%k%8PFg< qbdLaV7!SHf0GoSoE96@HpM7@JTC>9V7Fpna2nJ7AKbLh*2~7aDmQ-&5 
literal 0 HcmV?d00001 diff --git a/doc/images/google-small.png b/doc/images/google-small.png new file mode 100644 index 0000000000000000000000000000000000000000..3aa337ba3783a0951ea4f0bcd7b30d472af72319 GIT binary patch literal 4692 zcmV-a607ZrP)G?nAA!FDS0v8&Rj4P9 zC{ORiBqSun5ePRQ6b!335y9pXp&H221}qh*NGu{kjt^EQ6(%7e@pHHzPrCvUbrcci zJ+gL?#}bREUt(qSmB)5r5)z%@ere<%mDwX7&1Rn2BOhYA+1Pbh#cR8AU0Z@ z$TzllYPvs@kjTJ2d*oaqR=NvoWBCMf`}|meKnUk7u_{=p4r3A$y@haa?`ZPR1Od9s z{1wTEhfh!UWfGEb`Aj&9GM+#a{woe4oU*09Nlm|I5|VKFNVw1hHKauyP1`1GRyy@Z zUq#}0CLsxzPlQ9Bj3#)CaZ)5#(!Og35#!=eWpd!)N%37-2O#ejj3kcrM-t~!G;#`7 zErXx2F%WL@G6_kzd?6g|mrkdN-D-dYE#!JWF-^MwZA( zNWxhFeQV?pn~VgY3Z z{d-j`U%KqLLdU+_9M&flEjXEMY-`l@LaNI35!wIRs$8q+ z=T~Hl==aa1stRsut#T>pa<-$BaR3r3T#CLmG-uu?TY)-8`uSh7`H=0=>q8OY-d+ip zE-UL33%O6m*wVPa91AH+Fz#E;vz&g;r{6tL7xbQ8TYAFIxhgtVatJqz^#atE{%)Um zz;kcr2B$@>E~wPg06bP7EIO9igw54`4`_a@>2<eOHh@;InM9s;5XgkJq^ez0>3X|c zTO*lVXrjz&1vI;ipxZ$pqP}1k;l3Gucwbtv#?ja%O^myW2X)%*c8RQR!+FAFy0QDR zJjMEe$4fVCC0W`+Yl1t!o^C8 zj-YW`X5qNMkCl??0d4Jg;$)VmAeZ1Wx=F2MX?x83RNCh&Fg~Bz4ykBiIS68tZ6VU# zmTQ+bY%wgZ;BSLDMmX(3yWZk4$M)ifb-qCd>#cCpc>-mtdsm`@+-~Cqd13+2^>6{n z(l*nYpr%&#RQRwKBjvz;nocwxHpYQ zjRRx8D64k^)jy{_-HaYbI9NDwJ!NUz652LL4JldWK=f#Apng#X*B3V)9@K}%KPv8n2UQ|NbBkh}OYy{-ygrA{-vJwBmfpd(e{ia*YvzL|w@9g6=f~%B!6l z1wE|07YeIy_e8?MrTIDGl+SoiqxvJlxEC;mC zcX_Zf>Rlg76|8dEf?NXO{}rK$8+Ky`fzEC8!2jFr357 zcl6E2^4i_?M8b7eC)E1gMSNGn^A}S}l~B0bdI9Ww_+K%eH^{H3fDXtGyxi0~RS38Wk7T8lX{uxD5&2prZj zk(mQNW15|Wg8~FHEdE}Lud~eT1K%cVHxRLGfSQQ`zRHHf1t#wqmJx13<53hrEHi6x z899SAnsorrf3FL9FJ%@0kt|Nip?$5-z1Oxu{E+ph$-Y^a2#6wDfVS$sBRIsa>D(uj zGdG&7Xd)Gw*nA+4It;E*7Ns1vgYP?+lt07A{$>Q+$9+gmYffXHQEx@JP}^V0MBAJ8FXbQOni zZ&|3z{M0VOk^190z43quq6;O`5V)fB{+1kw)(TxMs!0%BF%qeS)o*Vn*5H&?{XAUL zue5M*k(=XAF0@uMrbGnyir8CSu15f#)Mz$sxdP))LP3AZJCrtYgUx=jgE0nMFYU&NjjYlv#S?0e6q~ zW=DZ2WE*m+5}`R1jRm}AP?v~k?t|DSKjz30X!FR31A(dW;uiN zd5?EkaUxs-SO*aN4o*I9n|p$gu9y)sl2YkzUm#l`)uVI z4>(21Cc?qRf6ZEPva$qjCzM7kCuq^i6z|j;&Ve1a134eg_iRV9*2szjmxEads1@`+ 
z5BZx8`pK~zm+}r~A!x;XSa(IOc8$t2ml6u2XDuL_O7>(t;I^*aK4e@1Ij#>-rqPYT zl7Di(UsaZ{?g_jn7PAmoF-4llHv~;E_a@s%IB$gsqJ2$! zv)dbn7Vp8VMyPW0p<+i=C7$daS28h%_uUWx$}Tci+H}RrkX%^=B+>PAeoBaC0BxlH z3Rk!mNoz~e0`u6;yqq6eI*6tP;Bs2HO;~DQ(8M#!a7aI-%z?~L3+dYgC^Q5c!AwD^ zy<@Iqr8-Q&<^8IQzz+QQPdP<{V=N(CiE<)mFjELUZzI7H=HlhX$nEr5<4zV{!bw4V z(^;e-(23jhg9-<=*qQi7P{U|Q-_eje1f_zcd)kTo98wYj#~vd(c_Hfzakrvob#bw9 zqgxa8Rv%;uWW1?MxN6sX7BnL&Ks!B-`}zp$1c5Rd(*Y-xfZ463SSLMIv>j;Pp0jwBJ_gusM4<`yFFLG9xD{Nh5k$x`WDfSPblSQ0>CC%z>jaEyq%%w8_MoQ6bk6( z%envwl*b5EW+xy(W}?tae_DChrd9Urt|N%aM+5?Cs~~n3_M#>w_easHlNI0SqAf5+ zQR*uyT-%2RQ6zjX@4~bBw6@f9JTn3iD2MWx6e3t?!H48%>6G-&T^%BV{2x|R463$Y zFk%Tq!45>91UJz05dh&8ouU;3NGfo>ReDd)Oz`9j0n)Xjmq|L>U$@z8OSJG9O*hH0 zXyH78W1r&!t*yE324%@jOc1nByHX>I?Urh*?lat|qQ9s2r9{Uy$HqRORB;Y_ z;hsbAem9dXrq#NcizsS!PWBn{It(z@_~&*9$kqlKm(e{o`k3}~!XcQUrEhGtMF^6a zKD?0v8-LQO3!F1l0Su`~?#C@>Heh? zxg12v=#>_eZCT5_LT)rsurX|-*|Ei*E;L^KO*fsMgNN?5q9^AF@7F`hSlc9F$iHMd zJJr}gv`^5#1W#3FBM{on6phwcpeT=I8pc=x^gPLw)HIkRS{-WLpg3U`0b7hrMZuQb z+R>{~=yy5CRl%OOi*NuA-^fI-+s%!?=O61sgN|WN0HmPEV-T&;wPdDYYo+&pNO`q; zC=`7pa~Rm%QWWcQx_Ju>-TZ*^A0#@_S=qMl7G=BPR8y=_hEB~)p{Ko&$4X)NoEbmT%oalFLqVciKR z6yOHpq)QA2o`iU!flyn2s&A6bq5wz)X$Sh7Q_vWcZ6^SEP)JHQb37W!AR!7R0-{bR zXoL(5;bKh=%|=lxy?!252kgTfL|J*dkorveAb>}o z9E>x%E7QN{p~8fO6o2k2IgVhM)p|`c#}1?J4T5&`lq8!Al~Sfn{wpXThByL8ULgR( zBqVwg5Qx*u^FlzuBqSuP0R+sPF4`lRghX!!_Sh(gvZ&ho20g+-{>LOFdZX)?k$BC| zL?$<*c+>*IeWpXW^+>og35nhU-(+OilaS~wK(37^ zGDnec?&1lxM(*DvBt%ci)$Bf0=o?BK@e7RCpydaXkmxngPSI%=(7JvUj5HPVdgh?g zL!?XxIb0?ok@qlPh`boRP)K=nF)BOIRXHPgf1a#Xu&{U)8dNa}iF^gWN+9u?LzN-z zXtlv4BziL-y@T6mJK@mLSwcdBG4w0Z+Ra;TTabQW5)u+w`-P?~BwGVr^CBi8A@TqA z?%)JK7zm>1_G1<4#wNT7_FxIA!g{=?0rx|WLI~dEzaEyrhS__d3b$IkH~>aQ{*xsncG>luWrf zbLY*h_tsta=WN*QI$hG$)m`0Pz3XsZ1QR+2lRyX%I>`D!tU(|}MP?8(2n0d`;ljXy zV4%?BO$>wZ`)L7%N&bfIpfKa#G_X*Z69yK90e!kbZ!Rc|4}HFd-jHzk-@GYMcn*5A zKm|U^{PUDoRMnsaxY#&3*f_bMtQ?$zLL7oZTpW~~973D|Lfrfy1P~|z2?P&i;^5-s z0OkD+b8>LNIe+iO;(!>4@QxQNB&zyc2QJi6uohWRd% 
zDrs{3_LRBTaE@mBVhOD>%>BBnOLIBS5B2<$2*XAviCy)}GNeG-ozEHOI)03`rYEbR z9;@hvH-w6BNXV_mue0vv0L_$d7n9_l8PEc|C%3Xb>~}=&#!f0LFxgx7{H53*1^<7Z zRagW*CDT2(_9klGB~H{4IU7l9~m+&W!-}UKbQkwAkW~Z}ub26Z!yd z=aa&Z9m=NS(m5N76n0n~u&ZGLwX8FZtjhgDtJ!=v$E7dc545*~YVA)L0`pF0^IBVV z)3sAyc^?to?>~Uz{kqlMUmZu>8r?XdvG$f*1l-1hU$XZW2s0Yy;Crjab>V*aj8$T% zBsrDVik0?X+XVbBU>3BFL(Vibp|#^09R#6n)x4thFt)~_-)b~0Gc?)p4M|S&aMrCFW(hbf;v3EG%h`W@b-^FwIQlj6Q8^Q3IHKI!?A+4f@f z)9!B72p0C9(c6&vjxa;y&;3pS^W>NoWIv`9MRHMT)KmXsThPFq^4m9hjFu#CGENXw zM)AQBAW!-a&e+d?=<7v#9BFHWhMnFOWtB6mLhy^9(J0}!7fnGL=YZkEvvHyyUo=j` z@_3H;1ViJF-ggtx?T8g&cJzFtfcwc~%285SU7+Ya%X*$jF>^uxQS?c#5L?*Ch7qCy zYZKp=kpvglmh(6S2R+x~TAoIm+UU+>PXyVn)iij=*1JFgRK^U zYMiM|@A*h64gWelb6I>11x;Ce(4YGa??xwqzman-@b2&%h4t&O?37>EhnkJIWok>F>Kq)FxiqSltlV$P|Vn|FfQA zg_j~m44^c!bsd6OjJTF11_r0bf4#fYk1(}y@zpDs1MiZNV+!o;Mov!oFnCSMn2p=B%ZgES2LiFckqe|Jk`c+b(N<1 zI_f=E2D{^6!NmjUVCskddfQ^qpm!x<(EQoEIDqWFBlz-GKhI9T)j7C9wCU=!57eiT zH?Vyv*0KFQgl=Sv0sJF-A-*IACkG0d2Yr*KWz-}Rrb}Og@?I6)6DXmnO7~|Tkefl?vJ4hrozz-QdPF;7 zDFboCUqO}U!q;A+id8NbzpB+dPx40VQ8apAKIPe0w0DYp0O`qvxsLf>N42T@mf68c zO+@p=EYT*36WOm8xHy7CuRPOp%U{Y*8<$YE(yyCKlGKTDwCc*IWWp4F@ha1?$>=v5ngq{ zl(s1R0BSJx^w3>L`fzzwt8i))3DmvdYYHJ5Bm;IbE^oge<^S;8=@Sb4I!e9V3W;aP z>5Qxu#x9pL8%S!4_eO~B8)nCe%+6~y1`bZKV1K;Y)R^0Bhe=wTZ59rWX%f)<{tFOm zJXGn%ug3;`OV%)$U)pPOU0fWX3FCZ?bS`$;?{nTK&@7~z^@dIS)i4l#*BR?In2sWM z=67G3Y*ef=jh~{M{_SIZ_GxpB^=u7!Pmfj_XBMG$Fz4lXX|eqeq;52j>T~Ugr|KLs zo5%6rDf~*;;!)WOJ%1JYgde_ZKAr*MXq`_8_L}v*uK318`^xbZ1X*08VtByX*Rx{s z*5Xe{Ma9d_#hhzPYk`(G8PsVO1QJ8EC(p+GjZZ}#$(32rY3t&Mw?0S!3uO20VvWh0 z6vX4zh;?<#7;&#Gvv@k6JKN>2ts;m=)O1X`f8hBMM-&`Pb-()nA}+TuRdfTic)GzJ zE*JOpaSI*~AhGWw-?==Fu=+Yg+PT`M#>l_tY*_zV^b#NwhcOKu>}e-am}F!3WR=t7 zlrhX;BQPd=KcinlZMY}oqib+fO{o^VrK@R1cZxJHDm8v@IHw7Ts#B zNXzlr1lerAc;>#HT^b)$0gA)RM@8?iKwwV-xLi(G*$(KAXq2$9B^j` z8vC-I!n<1ri=7DkH?M5UG|f1mM)IUso2qs_K2pdf_N08}$Tkyyy;!;=aqs}rfByg) z?zngN>^&czs^p*P2hMBF~om;$5FB+7Ua}lv}c9f0HtxB?61@M4gYs z2i;nJHd@tgjpA~DpK46Lu<@>3P0i%ftw_h5+|u(_a-iTH0NMBXd!) 
zCMABkn-Z3;AK3rGxhSKxC-GkU^+xKOH<=Kz6EAI}4Xvx7HM#w~s@xxojRaN+pcO|$m3cv$(;zRb9CZ2P9C3K_0M*VomiSnJ!-1*UXUH($8(hj|W> zrT0N_mt&q8Rb{|>)lG?g=FY1Xu@THgQ?oDFE&>FbIUF|2RpVl6b)M2ou2(_1CJ3ueq|#Fg1*i=?@s{*ooOd@OZF zpRctzS9TFOm&z;o#i-}nWT_ga^K31hFh47j7hRD?tu_EBgo#d)siiGf-d(>yHDN1L zE5lS|4NSOxlmo!PhduN6Y|MD+EdI?xoJ3!#%l4<^d2;pNC$4v5^cVD5dEuH0+^kN| z>+P-CR6otjynsx{3Y-mEeyY^-wBF(O8Yk_9IMr@vS{N=2(SCtgZm5zdPE&;~3L3t? zN!3OY9by2gA*(<1)jG3FetqNOryY<(xy8{m{voUq+N`)ec6bzcYtd$8paLdNa)Hj4 zhjRn#_R}#o0YWsntzR0ez5V$(GR#3~3tz~OOkVaGBrVWhijKXs(wmS*O2%g*H5iyB zBKSPM79eE!X3B6=PjIp5MSwtQWS0efO3H^^vs_oeK;?7jv%&OKkAv$s*gkV#RHrB# z>fR@h#=>TQLl}1-?Bq@Vm0Ro6BP}|TaTKP&f-HGKcF|jlEpchNx@SU|Z%LQ> zZLTgx5iiC3_}N~xvGPksc*1u@ob0l)T{Yf7A35CQPe~UiYuu%;6z#O=h~LX7qxRB> z9nG%6FWeV(Y=(wL)Qg)hZTRziu6tI8vCLK? z9NvQ1zRhS6#E~eMn778v2-4nfkG-c?zj^@GO=HM${CvLaOOW2}E7E;+r0d`25y~OI zV?6g9SwtK}y{-HsIsealK^GDf?(i8#r0Wmrk*l0H+-mWyuKUMlqHbo__KA?lm+lXx@2mXRThD% zaK80LCOL0J?Jo7&>!U8I4V7U?E18c(2bI(v(SoCrw&%Np`LNZE|KO6qkFwH z7S)~cMhhT)ch|6=x`ALulb9Fe!DM4LkZk`S#w)9{51>y$&#^_@M{!!EX|eFQ)k(-I z7O3P!K>9u(+|lZqSyI7_~vW8QOJ4<~~)0 z!9;`mD`$ISb+p@=rTiPZ4LX()EA`itJGz^gwprq7dF9bwDldtr28X>Ipq**N0^;nM zQQW1&x_2?3MxO#rYosr9BQER>8s;6aTcViYUw^ZD=lJzBuHmko$8@-^dh|qWIi$aN z`#Ch`x*Jf^@jCarm*f9Q$M!`TwFvY+*g<7yGsl(!E@|fIvciTW$>0Pj$5!bxG&|h$S_Ck! 
z_>{zszO~L~aG==GDzaSK!Z+M27|`c5RWrBt#^zhDFl@#Z7}t6^rKFFHU zdhfKI*ZFq1grv?>koi3QWk-0jRWFUtTHt{)Nm73+=~+3NY1wxy&&4W7vze8PYw+PmQ@0Qlm--c0^h&fN$X;-cvEf|Y*Y z?PmcR;riT>sL-F{7H7UyRW|(OzHeji-pAOGaSn(wHL?YT>>8f4J%CQcwfeI+O(7gB z4lqR_itot3*$`y-YdX!HR2kS0a)d?fpmpfRoF5t&Lku7fpoLT9v~6DtUDQbWuasnA zAn_bKAjAxRIzNXgWvl}_tvr2IHA)mbsPbCks_3TSUEayQ*3MFKCso2hGZsSiXoAc; z(}6Q7II?-TXSro1i$6k-F<~W{zlSRK^i`z4;LCN*n*HQK;-dg-9~Mb4$7q;RCE0qk zDXx+n<)2CS5ebE5(&_B8M(v6eYBS|||1*VV7DKV7=Z_v2cc6W(;$bgJZ2&z0ly-IlQ}VF!uyL?L4*&jc%IlI}gyL(u%D?&f;Kvjeiuyb*62(WT+v2qDMYW0`8jxMUI|KN0TXZx)@ zo29cOyH9ce8YL2kxKrf5iM#RW7#99?tHz&Mweq7^72T{1E}I=&y$1KbAo4Atd8$>FEexqFy9S%M!CP*OGs#7c;t$HKyzLy&{jnw!Idm4}y?ht-1Lil5cW$`Z_F z#S1p);o$qjukHr1fjF5v{O1ZD~=0-8)`<% zM`H>}J3BbLX}g$Pf}s+r|0+{6_kehTWgs?ScPLN&ELIK-EH!1a1{OX5B@*QOL#F}p z0XrDVKvi=3t4aY*0cy&Bbje?`Rh_M%qPhQbxuUeRhMTiB!~u$M*OHf_RFswGgUaP& zh3fp zxeYrM{Ozhx&-`zw{ub~LjJB<_x6}WF`wz_@UP(8=!`s=-UfoUI!VYZd@jtLjy8TzG zKLWJCHV}6YH($m7N$-Cm_{(<`pd8L_zJJvP2x)^||E|^nLiu|V5i+-YGzhdi&Aq@@ z)PEOR+L}ArfT2?>)aU;qgZwwXW^M)MwBiTzvhrFB@Urr7^9Zn-^MiR<1^F#Fp*_#d zZNKwTNy@6d^b^0D738Gm~T(;q@_TQF3czq|Q&w>(a=e<}Vr zUqL5?|7g_XWG1Be=o;FdE-nrbu-o6wgVO)z{U-w855<3!{vHU<|1;Nr;D7T-JG=No zCkeC8wk=)Lol9r#lULU5Tm78fzZ_)PkbBGlC4 zi-2cH`WMtA8#x^s&~g}}UTYDEyZN?^RQ2DGtd%PIez$S<}XN%?hZW+i4E}vL3!5MUM&~7FE3* zTLynxzhf&3FS#WEY)&;mYwiesvxoavxf2~pOa9Na?LBlh-tKFEEw3T?GO~|FcZtS6 z>4KAafYInXCQ;$!ars0We1{3!>efC}H(V#(Sg36)mI`iVgX^%}opt`w5xDHH7eFUg z9$qk`ak1D#9y~H=ncLy@{^Jz|ZR9#w=|Zf!1-x~Z-z&rj)FAAPQfGSU8b{yC4XhM= z4ZvdJsjHecDb++TZca3^m8xXDIkc>AC4VVBZ8$04$;Ddy|N zy^_PK3*7jby7YV(G=-@RympQE)K`S`DR_$zliTW!!kArtp0h2vIDIF6)sXR=D}HRi zsgrBuN^IN}xNw7Z5uDQACDFaj!uL03gSm%dY$R?^dE7*M$S_LA7kAbsjPIr(z2*%g zS4=#uIAz!040i`C-u0Y4cNZ1(XzYXzNP2AOaD^sO!NWc#M?GHf2+#qF2tvZ)z{Mw! 
zz^23_q=Jr3Waz+!L4f`B=6j{Sx|uh_%rBt5ZLUP!Y;VOiLz-aoNOT_q?QqtM?+%8- z+4F0F;2dwI7fBy(0@G=O1DBuqjI(EeVM_kD6~6KKDBAQ6I=g)WtWB52kYEL`yKcYx zcWj-{lO5*Yy*6T`Di=Cq)<7FLmQz$}Ul~N?Sxm1Oy^Xyuw2Ciw?<6q4!p;8W(MTdO$D-^rE+IV*3!`=#~ciA2{I&nE3Ul?OBQ4wC_ z^hzoIWq|K2j11T19A~?ahKbmyFX5O|Ceg(Fu^B(_m@~^xzzK1)L(MC_=9+Fg)vuVD z>1-JDHPX|Il>4{?&>7@v@$5ixW>+_WeYAQ&ui4l zPm6Z8kgR6tt-GpNjzeq_1_TG3owg#h*o5=}5-TZU{)9vo^kRceY*`QQBD9WM-;NVf1_bo zfyqF0&>sr+dF80&VaM2u;TW}4>NG!{;&J)D4zEN?I%prIcZ)|)1mKC+m_NB|4e;Uc zmEM9omgeXiHAE5YFVo(#!SXZ&oG_my<4~Ei^kSNARGnCc@u$>b$M`4YRB_|#>wupO z-q1vvb!qz2TS#fKW#~&15MW^pMK%P^q@(+r^Gzz%KPxyH8tY}CCQ3z2sagMoAB9Vz zhEd?PN`T<~WslOl%m|#iLUl$K(XTYDPuMyLhp!JpK4l0A8^9y2LanI;XxO3SR|W@~ z(Fco&fQ*QQ1kE#qUNEq5I8@M4jLWHsNG)mZipLe0nE$?e5}!j#%i;u^hTE;6XPtmj zA}DF&6p2u~ZtC~whR)dHFuxAOrR{l(RyicZu0*8=@SCap{A_~s6{`mYk{Y2y;9M>yvwde^a*$zN+N6I{|N@6eKxbQWs)2M@h zj7yfNQ>5)`F@=JUU7YQ8oXt8}sfT0}tCeB3Tl6mZ=}q6^VXp|!j^l|I;7QYd; zQPaF#NJnzQ=!CfTjENL1|DRk3dW_`Xd^IIl)H&$|Z6S?C< z>v~9r@4F4`tD5OmXbHXU!dw$#ax^+wX=S|Z={Q-<*a|N`{HfXA&@x zVF^P%A8``0>#f-|v)-DB>xGi3<3U~b^J@^|O-}xXWZnZvJw)uw$pgrfwf%|L=#!hF zQHKpa&z$dCBTAl!WIvF6s@14Qq6%8$5}3iYrQ0 zZ1Fy$fjQD4J58yJXLhmaPhKB=<-E^n034mT*;;5!Mk&Z=5MFtn8pP)3pASh?vGS*MoU}?>AEm z)8MkN6`x)%se@cdZLT5tm+}W{x7rn&u#_^%K5rc7h1U~H(&rM!R2B7@6G%3RE~G>a zf!l1P^6*N!-b2i3ss&;QU7nq;E~~(rKz2KYJ&9 ztFSr+3zO3xHM3Ae=GlC1`_{6&nOEC49cRXoi>F$nWz=JKWX$@XyX${W%{S`Wet$qq1Io7h62!cA!nb{`rk1_f_w)t zmd3*ZJRH-XMW^VfiI*2gTKW;TZrAKXM9Xvn$OK**f9~EK3}tC2b!)8Jt^jb>MARLb zp?OrvuJ2*aV|hvzbQe9H$nG)~CQhY2xbYFeE?R>aM^EO%CVnxa6cm=cIk0SH-* zBg1ewkD>_oy@QX0Lm`1@hf@=3oP>O9sw!rUcmxUs&IyjFMXD)C;trhrftsvDv?s>4 zQK?q=^>@v9|BsAOmr~WVKmC!pAoB?08g;a zp!vsh_@1*u&WuU!+~O?<`(F06hgn_ z+q8q6n0@Tt9HqJCqH=5a-sGd|!zX{Vft#_d2p&R^MSb13%z@;)0LU7Sm*7a1hLykh zWFyZU$5$Dph$;Y!OfR7Pbi%MU)ug3t6Rm{fGj>m8S!>ZMvJN6WA}rB3ie|LvPXG2q zF?x?8S;{xv#2x_$0dhp@m;KKO`0XA*wyE@jK{Q|F?bvzg@})^7i7aQEnf>4&Fyf;?gcmYE^fU`E!#Bg!B`wj|AmCifA5%k^PZ)IO;rulM*f$Y&VS?{_0nQ@#w) z)+I4rzMW@cmNM)1#yYi4Gn_y~AC2{Bh!hoC8!Gr?2i zkQSfRD6PeK3QpjtZHCw$Pvkd 
zi2D)4MMVca=3T7CjH4V&Lb5-uf+iXaEv)ZEhz>s60RoCrq;HtLf{{54Y?)ar=2COx z`haT;9V*YAGapcLX%v+ZP?xRoF5pt$z$wkA-JW9|n#f>KEGN`F8VGANL?=K}6kL}o zpA_~LX;A^rw3u+_+gd|Miv%S2RC*x=g=1AI z);uqTZY<7>8WHk3{aot1f7y|5=9G%Ar;_ZnA9JP&Dtt<*ofX~TI5)Usnty5!Uh2$B zL2JTOj)*8rB8{g;ui`uGh+S(2J$$HiR$~egEO;1Sypr=l67g0wtyE;WQ0i@LPdW$)Agq7?XXTE+_Q-;Pi^Q% zi66RAf`f&FLqUT3YoGKOsKUYH;8JlS;Av8GNSZ%xl1{Mkxg=Zz6DikE5ec*`+RP`!R$ zbn@%2qt#)NTd?nYc;^?Tm%RwXWtF^a$Eb11Zrro#0wE@)Gp&|u`NhqB6{MgZ0%50P z>+(5NLJP5>3S;7&yrGJcB7oc4LxPX+1V(o@GyNLNlsPIzt7*q9snTnJ=3_AT;K-QV zh_z=E*fG2C_^4}H{e>}G$ZYG*I`It?ecYS2N|*!L2CwEeH{(*j`Gyb9>Pus~N7k~* z?xK;ih|=Dy*{Y!}pG!aJ%tf3LUsKV|;H)9PWzcD0&^?j!CI}sj`U(gKZE;HdwAe2Y zv7}|2s;14UDh>7$$muFCiym7o^@AO;mD!_abt>z9rl#t7AIfc`GwbgS_bWd9q|D2n z$z$~NZ2j;)Z{bX7od)5ou~qk3#U>FA=9B$X8amZ?#Tgq}KR%w9`c>y8Jw?Xg>(Co# z3sslbX6(%%^Tp~uL5&71`B+wmMRPl=QZBfp4h$0eW$dho+?zKWwH+qLnttvA)zakacs0VEu-@Ks@W8q`Ut_EDk|)v3 z>GQ{s1N$+?fCmtekq{72{=^(0s4L-7aZ*can!6(4Lt_r8H(}F2V~+e20!ms7H*N{- zx(yy_OZOyb_;LDoM2#p8dpBVE(YYlHV^ywAJ=9#=f7yZ8h_dXWY5IEn-3_7XLH(V4 zkOFFLD?uvXl`Eb}`%I7eIR(s81uQ@UqJ07Sb)=P*wR{)ZRns6L#tE5wqY@=#kD@V< z0l%3WuKzvyo0F2Bpvc<2=H>#St4xKs?au4fW#r^!1!)C*1@eip@tR>?468#}_*(i; ztQ++*uv;km`QJm*LsJu2LH5#vi#ZD4W?#`s&#}A=ezEe(y0$VD?vaWm{)@kE$~1^Jp^5ba#` zG{kIPgbbaTBxP%j;Xyn^Q$$Z#U%Y6$4Z_Y+O0k6}+=AW4=PiZ5%4O_PFnj><-&(#k zWEq`m!nZd=ny)e61UJx@%e#a#EPq-H`N{MYtJ?LdY7xLanXGM%=f123Cd!BU4IM>` z4oNn%MGFFhZGXt;%-=mr(9qJCx(kN&sa#fnB@7|GOG_~`7-}G4Vz6Fm6OaPag(ZCi zEt-Fz_hXIn%}}Ruvr|1NsGzYP7eMrRE?SFc1>NGdaFM=cn>|)G&52v@{1|7lZ>y!t z@XN*qqf8CM#yW6mr5XjQr$WUU@-gpe3dd(mZ$>8csJ%xb%3CI7R#nTrz5r=x7) z7o`-8)`o{h#P%doo}Ra8lKRu)y^Iz58uqVYONsNi#Xfj8F-niJu*~K?LCzi`^?=`$V4Voz?&wy@-HeR*njgFO7_o8gc z_gj+@sBbBr-+n;^^kl;6&%VY%znamL>z~?EuccOx0BgtU4e?nf^)V^8r_OJHT; z=l9%(m4v>s4&Ff!$z@5ox521QDD^kMO8K&E3$q+!6!2!yz8YfqJ=$b)Jmq;lHw%^E zEc5#>$2|;sEFYo>5%of7hoaT%DkX&_KjU6}bmrv>4hyU@05aI}3j>??wN=-&+QV>R zpYRJ^5h!6SB6cua1YtEm~509^_iDiM(2GK~9t(aM0_ecLVQ;WQ5vDX9Wkd?Ke zvh^)K{Cb689pH>R(4UlYCE?}#MYw*M%xJBN*l6imFVQHjOeI*qJCBj$efWM4T2_g4 
zO@eIVr0UC2!w;@{YSFPpiG#>>f}hsdUionbh;5KC?-7($1MC7n9eU#^KEWcz=zfr4 zd;mr2K7e>%@Un&E!%-$mUZa91QeYiPs57!(ed6|s?gwyT9Z{3A?s$80l*phX*wf9-i0r+vFM_p5l|jtf5c;X`(AglZ#VXi za*XTT4$NE86R1$)Gbq5#vW#avy8$sx4Pc~qML|+2uu)|Xra-n11BMX zW~T_(CYjF(6D-X0oA!kv!j!jU+IFvIQ%z0rpGI4!Oz$48wo;^f80!nAGwRAEJEolq z?`C;#VZKX?l&*JBL9%;G#rUyIV6i5YYmR?{AqBg6vBgL$q#V>#@J@k@bTq-ASl$`SwUgZq1RC$XgbDXLvo35kmG0=KihQo1~>5-L& zL8S{D#Vi2LQl9T&Y4jgqU!mm%3osP~fwUT*kzTIjjw=)^dQ=>ZQR0VX5!;CLBbsW& zX}+~;yDSbZW|pl9wuKH7z`LncBQf@MUmYogj@V1Du1IIg3P&YyBsfXR6s(xZ9tyxN z$mh2SmmBxIjZs z1UcS!pI}o;O#Wjnh7qT_pPo;HM=YU_kjCb6f6*No@fJM>NV_6JP5?`Md)tN>@gwvb z`*icyW@ZFg5bzouru*KY!LBBO+Tc49nmYaK$#-mEv_eKhAZ-FU?R5Cp>XAMb4*SZh z&}A&5Cpb4Im7X7gGE>xUz*pSiomrX~i->u)3tBb0zD5QcHZ<(>xvhj*CFvL=JYzGV ztG7?`%BjU@!%Zl7JJhiFePix6PPWPH3SjiwWTa*UjTYj@OuvT-EUc(j{FL!@1I<6ru~;FMG)^PhH5oT8n=Hsx zRPT`SqNoXC&3nRlQ`v#zt5Ewf&1Vz3kve`&&YcVES<1&&yH;kB{f+9-Px~dcZ6tik z)1wdCFhJ@9=qmv_Psnl?X?37iW`u+S7nU5>d+CCF(3UZQC7k5VIJMAJks0c9U{6)T zHUUqdgioTs8&NFqN`VKrkp}Mmo7Mm5x^+ zwUZ*-i5K)vW+QL;T)Fzk*Vq;PJMQ8F@BFroKLGCEz~{ax0{(ih%i?JghY4S(b!4j4z+%D0g;LhZ%0FKQ9V!O+nV<`}okF@Y7x$x z-^NREka%gl9Kos|5J>gX;xKSiQ-u*;TH!_9USw+Fcq6UITY+@-tB|vM-4=>A|65fo zw-tao7Xfb49-mZ}eskC}y}{+f`|Wiqf-ZX1gjxOklRC4W*Qal=qmJ^M@!X^OBkvA-5<1 z4a(XuIf0lp`xXmt;tSJ9^XJ+v*2d5m9~w`VmE` z#%PS5;El$J8>W zmZHLJZ~EOvJNUT=hgDXP7uV}Ucy4o$e3X zl?{b?xXf$7Q;sktYlEc{5-`rqA>(*5ad_RQdE@xmAqI}{LoBw(FRA&Z0L#NLs%b;~NuAL6==a-j=>K*AfuJwMxqr{! 
zzyC=74L-jAR``!%xkpe~0{REG>=AtYUN;CNpnU?5aN2SL1z$)Tjre-gQG(!M*Tc>{ zRX#9y$~3@Ty=^klQhdm(vzk%WD}wj@z9_%7v#@`oNDGAc6aa%riFC|L2W z^c=ScLjo@+SnHNxeE})U*6jnLH|cPty1IH5Te`zFx+C{yQR;Z!Z4Wt#2ff zGOYxZnC=a-2=SfYnSFFTiX5)ehQ7rBl4l7*#*rj$vwvW|ri9uE<_7sH^oxBFMPru+ zZ#ZiIDc*#>m3JD|!8R1s{y=xU^?d&iNA_ zL|Bcw&j*OW!?qIvvM7sz9TswBvN6oPu2NlJ_}>g4idI?n!;d0R4AI08RKAi0(prL+ z(B#C3x?nEQs}Q8)T+lr(e^y@$P$E&UgJtF_*)CvPiB3k4B{c4&8kfonu;5rPGejc7 z8_v{xM)tdJr0(M=wuoLxAPQL~bG}z-!|kyn5E%{sbV<_f;E#86uZAFaOV8dwHMBxN z?LKX48c@+2#9Eq5|0pKq#U>(OM%+;L6H8^eED<&=XS{XU8bVMxX%c29Ofj zcr+j6U7-|hFFfZK*Zp}mlb$3iw2T>ohJY{rTCo6&3GT=@;oErT1<`&TzM)T3{Wn{e zowneTH;g+!F7`~5b9 zN9WhhYyjtmGm&^`LZ)qI@T~~-P26~KJxz?3iM!Y5uOEtixZw#xbw4uwisk(JNmHdv zmai!2#sm?@1Rm4sZ1^3cZcW@4s6sC|7@fGhAMqA5!McVZnmu+E?hC+x!?ZX5;$6$r zN*M0}I&D*Dq=|^*FxhQHy0A^H6OW_au%uMu;UkVGjhMIwT6nVaqBRI1jGB6TJz@g^ zd9tmsf&$vBudnQW(8~y5O6)FPRaY+V97O5y#3u%{U`W-1;Me`V$MK!p)4se`0U?5T zA3&YXq_I0Qg1OIFOyU61wve3+LO0a12Bh_*TP*iv3PxD*x`0(u^1Q|PVWwY#Wj0zJ zY4^1_Hn*!E1f#oO-|#Thsr!E*iaMI%CO~xshX&$>G$m_uzcY^#B8M}MvvM%+p$x=< zmuc&>kr^GRfOLv?FmW=Y=n+WJ6XY8_Sx*bS@&M(1kq& zf(CC-)x(>}5znR(iCXe5=s7CCKvBP~=bkCK$m=hwT z!Re?C8$Xdq3Z=Vsjr!?p=~rK(bef|4tDhgLD1o#B^D2?#;yQ7iN$C@K93-UZBBVZZ zBT&O)!0K`YI`MaRll^dY{~V4MwA;o-ZKR^ZBcRa%NLzL>>PX<>Acw+h5{mS+hsM_F z3~?c72V2FI(^BdNJ#}$?;iiYeVFKX&v8#wfdq;IY`!BCl3T|`GJ0$`U!*(gieGYv8sy5kjY?Dg9{gthKC=9m}7qm0cW9?BWn#X zqDH6C1V8=uR#OKq@wJ=PfO@|kEN2yhN=TZdQ4q;ObU4afyPAzWl7;75E|Qq%z=QCA4=6B1Oo5PVFQwXa5T zl%O>dx1f=VGD+7b5gG3wf$00axHI@s5rHroUEwGh)$6C=`8xQGq3&laE-7>qa7w;} zTrd-FqTi8nr|@aQ1`6~FQr@CyljB>mIl7=9j+!J5AeSN_b6B~Ca1$ye+8~TYA`{sg z0I^7^0~;E;ybiGafl4mEL?H#bN*bv1PrE$VCumF~@V^ngc+Bu&!hwcCaL@|@=I^XD zXujH?FbEn1apXf|AZ!WBe;oB=KY*q*l&8r7Jk1iy%LmZ=c{79huFdKn)d1G`!5BO+ z#|(d7`W-mZf9F)c*73^m0c7)`W3XL#Q~z|lbO8pPLc_3lH`#u5vcLlRZvc|*2ft4x zA$eQG{E}V^U95TDyOT-(#rr0Bf30rwd@%|3KoVzp`6mASHorvd%*AeVHL9uVQm0bMC-%ASl^Pl7ZTR4a$H|5tzwTMSDjaka%qrQ#! 
zu*SL->ZPA&*$*!DvV>oWMWA>T$^PWIX_w70ae4Is8k%%Y6D&y_2pb%obZs6nI}Vm+ zy$;e>mgxGZX4ce{!cDWrYlgM8O9(pKzjl_DmC{+5jxm`RdS_YDiu8h7?L%ZResC$} z3IEp%z^){U!Z*LW^1ZLTGf3(5T-9zYBV+$%>lIHy&Q7`$Cjau0)?j|+^zwAS30e3^ z+XIN1XQwG#|G8QtA&NUOMW>RKD%;PU-FemL?}2{FINqasL$1S@t$wEpO1rk&KK%O^ zKR2t_XC&y3QfDvA%{gNh`eTiLbm=tGDjT5p0WBD&#;u5X=+Xhr9WZWghKaT_7N^20 z4Si|n^U1028(0%6U*jGv$KAD(7Q-e)pj%3&#=_V?ks3Q|NS2sLb7%(bRAotxYkjil zO)P&RaI*hq*F&rHqV^M?KmiMU7K^v?MbmjXV52YMyOmW{ZD7(2kd=*Ac~)WEzA_NI z05U~ZkDQHWe){LO3vL{UJ%HBGu0;kn^7T9qlv~6utYI1O7t93f`%8O`jNW1$dgZ*g z3rFU0IEe}SZs`I~R23D3u0oyN^STQ?kupPpzfTAMjE-o9zJ6vue|v{r6fq91wTfax zjy7n60DG7_NOS7qbzRCEf3+^+RHkp=-T*B+vD9~zr895LaUGJJ2xZ@#Yx*`q4)SSN zw!Ui|b(c1mHzVen62e9n_4T$KeDFh*Vs*z|fK4oFfG>A27po19;YF2QKoMFZBbn1a z{KaG#xbi%EcY$5`l|8O+XEiAJ0R$7;k-D4Y^>#^&HyeYdJ<%I#uT0s5bVd4~S`%=?>LdG~qQQZ}Lz^d}hyB^4Fo~{Rt?I!71na4@SwtD@Uhi z$(O)~;BWN&&`~96*rRaXOcCZb=EHd6?WW(p*r9Oj>Xki7ykH|HJ?#5CH)J z00RI40|WyB00000009vp05L&PFkx{ZK!1Ue5TUX0|Jncu0RaF3KM?&-vWkAu6(1m6 zv}Ae=uF?{SEh#|%0C+tyk>J73O0hEA9+Gzz6o``Je$)l{%Yqq0IQ;`6l|xZb!lfd4yLyXJsCv4% z0)ezL1#7A*#nc!uKn@{fl7EL9Fo|>$R|4`)A&Rol8X_u(sSw(6bVu@zi(Kra#-Q)) z4^I2n2S?zb9tNw>6;uIG9 z0?M&gsFIKh8`P3%M;#q7ffWQEcVUTW0Rk@IHxMzfgO${&VOo+y^nasNuP*-pl}9&X zl*HeMxsev<+2T8I0Y-kl~@!3dK4j=xOybnLxPG`GZTUA-R0^z4F3QeCtPl* zI^H`c4iwXP)lU=>nSX(~(eAIv)OQXvIXPt=Ex-+K)UT+Nx5m&%W@ zzBCvSu=LF*-p)#d*w6MK^y1GoaBKjJ(FyUxNo?^7wM7W4p^C`Uxmk7{Y)S&~My!;n zL%a%s2qK&`Q`9GA<1;tn@*AUp5DH7EhfB%G(hbT_K@gfQKyPm!fIt8MU}2ok(Sijl z3V^696c>rLrGF)`P>3RZlZXSjQCOT2;uommbeb__(CHn8Tp%Ypf5`{C0n7>|n?;BQ zu)%4#1v(C6XF7}^U>O^XhY~Zf$YC&n7LC2VeR4QmT&HhyXS`~l%)C9iCBKj(@{SJJ=%=-h z&kxT9yQQ+DVOP+|IaBvC?7rx9N!TL9)k0c=n*w@$T1<(@!VL~7jqZHt) zqMV4d9H<|Ic$wr;OeoTXF_&O-M;BdQ6-Exl^cAF!38$q92!$AP=EAx7PK;mip2;tH77n}30c;D)ISTy$Du+XS|tM_iQ~hFn~QPLssp9N1GKm8 zU&8{+8Pm!;*pYF>(03|$KMBF>4U0$vR5-GP5rtyN3Yr1Ea4{(w(ok0kMWsNeCrXO7 zQ$Z0jEnb_27z1S2B5xb4u&|VXMiGfmAHQ;7cz*)15G?7(z+h^?WJE-`k@`&Jc)}vk za4v~zac+$VhDoqUfxC?|YG8zibEiR4A&g2mNklIPF{M_F5}uuGM7SH=;>Ip?2qNH6 
zY0QxVq$?R|C3<=x;6>gPHY)NgoLE-VgqWce^JRh_5Y)iCBq5x)X?RqZXD>kI;OPmB z&|Dx^yj~cZbVdnVDd82NUYf-c!q_&Vm2IO&RRARzg-KDVr2_D>Cf+qym03rOHGiT16 zxp(eA_j{h()zy2|TJKu5N_Op%-J#?`=Tvf9fVeuX9i& z%RkcpxXQ>WsZs)*Y`kn>4iF~?7%T|=a)K$ryn@^Uf;_w+5JEf%galQA=imf$& zZhx2qfn=dqXubU7nq&-9495bpZ9;+_g#pSPPmYgR={h=6Ohf2cPKH3Jh)Isa3ZtVA zqB91a?TLicuP14;LCi13&++QQzoc#tqLo0TSDKW_2a+B)Tmh!N<}OF4aN;bP#~8;A za!A?>yZSM6WB%UVXP2A&wfFop3=90zYPcp^(jf96&9d>zrK7r6{MHYGQMb93p8+6m zS1^*Hl0PJpF>#a>i@R~SVX%nh#2IwU4XZ;TrPI{iJZF4)R9ny0NPdW;4EDAm}qIVIR6{MHBA_5_F5r zs4kuLuQOX8EuTmw+ukC#X&ohpoYH@rI?h$1=@mqEGHhX zw?Tr(Gy)=*y86wKmD2hZhdE%>qG&0bFxhxSh7bf~t5-$1gxJ&o1i#)`u|8uZ(n`$q z)L*$TnJwzPm^!k*_G3O~^x!XDwbpFSpj-{&7nqz{ijwi6x5lTl}zdS495qa@w-Yf5|F1M-YNN)gJBlz!uglE8ef@9+NEcFl6dWpC;BrkijpXdADzGw|K=zNW5 z!1A`QBYB6OdD+9M({_bz(Mv=jC1RB z_0TO*6xOwX?P88QKmKbq09!4tI_1}%r&qI%z}9FUZdQiF2Pq4{7Wc*1+8U!hvup<% zfk}Z0B>At0N$sEYlS)LjXD^k%Ms9DQwUMKRn`P1ynkPsw-oMp4z6#)+L7{}q9-Bmp zX-!k95|R(WWp+vlVV2RhYfQx=Q$a#aLTDrDzJ?cw{X~$y;dPxn_-qFU^;U$K!{D+C zSKUU0&`tifJxLU(@jyR{XG;@KD8I8X+=_LH&Gw)OQ?V^)-%Ax2I6J>|>GX^BhKoM_ z{u&*Rjc;v%$2hI;M-!?Lq_c;@dg?q3wW%X%5M_AefYfV9O^)uUwdNUtFG>Q@tXDc8 z@&%G`E&rsGM7H)dt}Di>^zGhHy|6`B$vo3vz2lwOC90Nyu+!?SQ<;dMCRhnShqiaN zua`9u7yTKXSFF#L$op0iyTZae2B!0tb^zZ}qGFBbYipKDDmR;^JRj$o%1d{b-pkYD zagDk+xUA$+?y^*#^-OX@wEk^FJx?Gbl4dolwG{Q7ldT_(p|u69<4d8{dlABg`_j)q z4oj=UaGNz4Mum~q6H0uA4NuEYTPW~JVfd`5^rxLY^prN{BJ@dJ_=F<#FGW#>y#>7; zoE>b?%y+!Cv%$-av1fleAHak5O;S^yI;Q$l5^lz&Ji zE!@mpZJgb0oE-l|nVUgDVT6E!BQ;x20AU}9y#ZtHAeCCvUWF4PYD z?`-~N=>Nw!)MpWTYj<~NL3VaWH?}`KW;1hgVD~n0W(Tuz{Pr4}DHJ9MNVr;kivKPerK+`)yOW!>lQX4+1{bB0iJ6V#@3b)cZ~Wik|FrQ3&qC6~-9q&Dqmh+^ z5Bk*va|?0^2y*i?b8rfBaQsdCAILwIp9v~CncG+N5BISC0>S0_swdkaxFb(vR`a#9k!U;$oUR!%nX z-}V27`!@rCpsbA>w4?Zl{vG{K-v6vd=O=_xJ8o4-GxHzDO_&A^u3g-Kr=7Gwgk&By)Q;3sK zii7(Vrvx9bgcK)OTteWLl#~GPD}Jz~q=10LE8ag{{x{eD&+Q2O+u-x(C}shjj{oO| z{4M!gL;g+*s@iy4*y~E#K<_NKKiX1&7uuZvQS`gw{w2_Z>b&_=i66}Wcb(s&KP7+Z z{BdV!IGdPR{O&G)nE4y)FWA2oyeXDA7e@O}^uI~|bCvvKIEIe;(D9%Bucrhw@zNWb 
zv7LoGrG+E(N#gqZDbhsGu=fo32*N@J*@B+nV6Z^2STJx{Fi-s;a_9sd78VW$=J!1P zA0K2C1SEJwR5;k*a(+w@EDRhx3>+LRJUk-Y?^O;g96SOR2oVXJ3K<8COD&GV`AW^i z1rMKwi&otpQg|7)C;!GT|)}FB#`lGHOvxwpH@K6pzgq8 z!C`@3f_lcO(Em#S{u|)$a?H2MB7rVtWIB_rHG_>$AeT}z(!8zVupC;W6t3Yz5{d~h zDEODCY$Xw5uo^Ri?0b8HL zQMJw|&}d7lkopKsQEsKCN7bNY3_fhDCMBa|7aDfuKtk_kY58gBvB`D#5b;^&+1$Cx zCd*>W8Uy9%DudnH>){{mCD1(-Mm>TYOFep;i4oW9y* zVp==Jni<8_qLxDXwpC%=z#>K}|Mjd`a!3d|9G>g*zxzeWH0iu`fTc8 zMEZ!ylyr*DWL+-0n9iY8w!IT!wB%6y0Nq+8lW#j$M~Nt3*0b-2wM2g$Xnq(cr5=cugYfGvmiClDd=TCnwH^l}YPR6A{=9PL`72W?UX z?V;S<5WSIUxDE1O0r(gBU#VJUug6zMhFNNb|+vbeGhlMXilLG)* zH;i)D{%bD-rOYfdH81)5VVgr5C-*@`7mcQgt-T|yY)ma0+E#7rKL4W1ERhyCb`gwW zyPnfxA?YvnBw>spqZvGv!~>X(SgALb&RF)TZLh_xrVN=A%sN(vo`=Q%mjL`90FkNJ zE#Ha0D6FcET0t9_z`rQl`c`0aCevRl90!gQX(+9$y~WuEsP2Y}h3Jny@*B!` zxlflU&;**Ay&cVD+*}dVr&1}UP$FkJmS?RI8(9`5l#Uj>FNO=}RI`yDg@zkmJ~>;g6RY8v{q1^7H4tM&DHM5tqAY|7Gzu=)ru zPJkm^i|k6+-nnn146lp)`qucg;#6?2qB2_PEu{7*)`ISm3PbwgV9TFC>WJ$)Up@1P zkC?s>NSSYp6U#Zv(kJH$fQz}9zL+=MqE?2CeT4O>0e&&3FKa*YEPgesA3x$c(%dhE z)X;?zxYDI*U73DW?!Z8%jR&=3SwJc(53MUI*)i!@{o@DcF=g04B*<$LNF`Nzp}NAk zzxilYKEVSafHZtM#KT)spi9Qa3m?&QDq!{;k)4ama&Y@#&U3@kCoA#q0r)E^W#_wd zZT%BaV#jhRF;be;B4^X8Z{3tYi+*sae{32qyAAvre`6>+{H?N|nbVsfFwki88`)~3 z^7tS_A7+igh{mn8XJ+4q$%KBcfsnOhK}zLp@DI_uNrOW4oK~2y3>Vhk_T*;w=DR)t zav|`K7*4d-=t>~WV@_i>V?5bkMVvfzhnzK`PqVbpv&>v!X)I989nRB^v?zgBNS0gFKP8FkkFW(s^ZAWrbZQ zMM$iCQDXTNMg(o5itqgJ-2YN~S&84}m#*Gde%){uyHvhu&@$OXf^73e(Juw>d2umg z9+7czqCZm=_>9N2u)v(n@(Dy=EbDC)R(6iAJ~JDmTs>qiCdRh$>We`JCHi*4A@42Q z{qB01YYl0X{pb*at9#TKgAPR}`KUF`QdXm#^|u8X%~et~^$K02;_X3ey>N<+R? 
z_V}h33f$FiBG?TYjT;w191n(ER|Hj2yd6jB%iU@N~?=KJGm z+A1ZZ#nVjJmDQp&ogvP@w^=?H8MIe%tED-qD5rsc0?iT!F3tYX`8?XxB=^mxKeYZw zZrN?Kw^xjJw-=69W<(~#jJ~Vn-rl28}As4d|$OKx)PU1Ts zG_TGWd=6;`eWl-(Paw)!Y8sKHn_H0=mOz{Q*FLQwSMRl<#LKo4P}b{=m;?g5hPl&v zP{S^IwB#?=>j9U2kv*CE8|9IdZ)Hi%*KTVeEh0<_C(7qFLJY{ry$^YQDbXC3#EY^6 zW4Fjv=iPl7tCttCG5sQm3cyF}%-Yg&ScH%WujdELQl?+Bawb$zL;7_IH$IddHpO`t z>alw1-?Y7 z$D$(nBd)&N6kbvFT9|7tn3q&Is}VlbX{nGJ#GkDipXKnZm`^d+!=FF}M1eB)Aw0yj+dkX}Oh1Od<==e= zf@mfL-(%cp=|xax@kcOuj-d&VxuWR;70GwH8{DiqhiKlOFXHUf2YZ6)+=hWiEc#z9 zmBOaV!=he4-zk>dFw**IJqy6ZsN!4T^rV<0Rtcb47cxdw?*l$GYvttVT!!yPBfcCa z!g)A6s^uU!FU^%PN}x!O_MI3&R5w>yjCZf6A5V#ZdyK&scW4v$DlY4o0{hyQCwixc z6c&3Wl=jGYw!>?UMhHGl2ra0O9j%(dt!H%3DnA_kG_LOrQL8rXT&rIYF@`{reP% z&yrqBZSIm1EfMA|`2M_VLubVyXP}e1@Ki&OBR0pjN*#sbLP0uVq#T5hF_xnyZh3}6 z%)L#1*Q@Z#sb#!}#wj;p(V$^BzsxrCup|}tbz_0C2qDmjPx}>@ZYdtISSmt;=ovV( zSazSGYnrgJ%4}GhN)GfByDH%TZbuzVDt_#irZ${lD#IL=MP9hSuIS>p=u3i5x|p@6 z$nB0hY@S6;t<=@naS&4K>i%t|0l|#M9!#$FQy~k*2W?YE|76B2^-XVwh$J|Gr zcTc4%1OhW*t*%;w^0%dG2yEh?$dB604z?2MUkQiN5G^VVV7VLQm7Q*p0dPkWIJGc` zsInTLW4#ugP`?WI3Dz%#$6frYz$-PLvTxJQ6IH8?o&X+DosoXg>-G#xUO+e9j&XRN zKz-X53+0=2?yMK57z=bypfV>t5{*TE2>{YDQ;A7hS*@#zvL}Q)(Au^Nl@}gn>5waygJuPkFa>l)i@{De5hUCtMUmtOPK7mpzEK%8f0BI#M+L89E7pk2? 
z2x{}`OBfO0eXE>y-<=7PiLIX6^+ckq<$^NW1Qj_VYKB2M(3ezy6MbnDmtRC_ea&)m?CLz8zkJ@WID}=V-oHaF&U(XIH6>GewKEEPF?d`yxZnK> zomy7(P@dqT`z5y07O-?jCCI*_7 zKg6>y;U=1RNmlH(FEgkcBuq_RP2PjagYMQ!H5GUpXL$yoe%{J_v#k>m+eBe{e8+TD zHEeh|5V7npxw`3Wuo>Bg9qa4E{La;g&3T)#L_kDaK8`{~F$hoc8M(#@nDyiVMT+}d zlZ*!?SwPX`HsHgs%O`kUKXuMgP?FAkt_@GupJCNs)iN^r1ZwAri2@U;nuDI@?8vou zy*PSJn82PiXGiU?*$=dlrr0hgWjU; zp~Ui;qajBLH)9ln`Jqngx3Ro?o#uL6SE)ZFO|k^MXqyvBlR{3$cHx2y;6O07tupe> zGyN?~7(m8;AieJSAYYW?K;@fKr{?{eue@3$)|Iet+gm0Y2p*L#5O+FsX3u1ZXUA^s zn$hSCr*n(N_TEEk@|rP=M|ML|f|HgT2(}gF?f8r0Q7j2)-Gx6u&T;a|OV^QtIOJ_O zSMh3$vGp3?)qYRw8h!?$hS{e`YH>uTyVc>io9NGZwgUaSp{xR& zv;g(~a^}Wl{NcJvXKRaplPG*qe-*2afuN<*G#94PIiKV&wayd!)XC{CK41H!q^enW z$MQ^(GCBH$J^ey?uC@7GUq^($xk-qf2_z{UyImp&MQel5XvFq~J*GO5D ztgrncsk4qo?a8tqFF&umI3q1uFr@3d;5P-5 z*P<83c`ZNg5v(1Bv)ip^VY!-)2JrZLs&y&3aD{FdB2<0&uCAI>#5j1&(9{bU$EIXG68Rq`M5Uuh!)zFvVKc+$z2ALQ*N~Dcmz)xvn1MrLw2%}@>r$|q zv_<-E$PpU$7%mfL2$2wWhVIHw+`wy5U|Yng_b0nVZzCFAFq%(N?kDImal=1?{(pJ>&|77;qs#p+r#_NGi(|h zNzC2o488sUX%+S35WI%a_oJ*FSfd&bm$7N9h|DkIkTTdDqLdn`a~;wFu+t5~ct9<# zAwJeu9d+Gv&3~?j1D}PHtg;AE4Elegu=k4FNVp(=^*b;hT$65IlNiM^LSni`6xyn0 zP{PlYc=&G(VMRH&DS^yXq)!V?xEU!d0(oC~9%-6A$==_6zH) z@(h3QnIfD0o%Dp5qzKV?6D%)NtA&sytjG%xuCX_qdiq(m6r0m^aKjH5%N7a2jayDn zafiZq3Wq^yrWa(Xq*!4_9cuBCl0L+TW6G~ZvvS{F3IpBUYSZ2-O-w(%Bi?zb$9C<+ z3_;IPt>5gZx+@`1JWo){t{D@m%65DL?W3P+ZSZ%Z$}==J&kZk)3}%QbBUzcgu7?Ao z=1VX)AeQr)eHB2yrg-k$MgLsxq#{k@SYUoFCW`#60YhU{7Fwe6Bbg?}vNBT17D$pb_GmW74_n9TL1Y9gsfw{Di=Yi1hJ4 zizP1p?BZvc#qkvdcK7#Kl6nI&OmJOn5=GvsGD37XsQn;^+@&e`qFsADdQtwyqz>4tM z)O}2BJuCiJD%sA}BaYcODfg7`4k3bN=~{@Xtp(CB~U6W7V`bF!wWa7&hx#D>Ah-LRUfbP1E?{DWFMnvz0{|I;H?b z|A^k2flfIFcYwXuwZfmL+#3WjsjWU5a<(PUiaLL2w{lc|2m+AtSp?S5#lg?Lg+faA z&@!q!kfNN_`R^;F#Z$Til3!k8?S(%^il6DZ!EgDaH`rukZtLe^9RkVZT`E-jk{)av zI|!djEYi$)-oJB7cF1`zeiG+>hu8vStUGl&MIVFFvac38Uvrx>P@ff5vcHSuc}Nd> zX&VBRcA_g;h!U}Qg{q0wlQ?{q-d&A{9|DyYRSeRT2A2wvG;FGSSqj1gng`>J9unm2 z?#a&D9sjT>Sv^xY5r_|_Q>Z&)DY_}VV`>-Rx9kHEq?}XWKY_3wD#YF)-k$(NDMOS3 zXjxyC^I_J#Y2HesD&F;T&#`?iU 
z!jXrLHtRwHeDwB9=rjKaMA{EHq28i%4_WUn;9hMMuc0?1Ih%5=0JzNyzbF`fNJUTf z4{51zAf%sq(wKG}v6)|MnJ!@g3f9Z}%oATg<;cq=!QEUek%8M7XH*m7&!)4<9KeAr zNpkMV5}*~eex2;la`fK+rl*8(gblTb1LWKxDZtMU!s40TyiH@`zB0*kq{;iU^C7!#Y(vP zqzWbaB-Qn@^g%Z}?#a?D%JzhQqOomv^qr^^5Z&v`A+1gYl52@%bicxUTWfwg?d^t+ zT@o>V%+YqC(b^vq!2LZCVI_@Nv7bJZNP4Cui>K#A@%8Bs)vt)l!Dhg7!}k?NI*?qZ z(mdNEIjZ)P8w?T;sxJkU!1W7LVeX&-oDwEzez7UN&Lp(I`X{Eu?Zvm zn}56J1GtCp#Zkl2p@ir9B9qW-WaCobmGCcd4L`_=nK%JB0m08+$ZSeoreqTDRq6sy za)`|UCPp_$h^kT{18|!6=BGrHV)K;_Q`E7c%5*`CMsrl&{)^-RH1=ybl6$yn#`NUo zRU?(TB|!IUn~Yi8KtL4rh&sDZG|Ux4dp5-jwFB#Lzr|~j51K@jJL63X@&)xN1OglR z&;uea+&5Dik0L`9n>)H1V6&6B+@4vLYWFebdK6|eJ&~-+Q9w0_;+oel?uPLyyr=Ea z@>2Oe?Q22LoO*T5yRsH>(q||?nFqVie+9KK(C8=y0_^KgplX$`ek`9jw$eXqLymc( z;?xOuDIO&~{Nl*mO7AVIy7gL9riPL9>}4-x z`(dF}lIAEolL~Ai2>oQI(7XyAQR8`$g9W{G2`FdYQ^7QfmBh^_cMq z#O!};y}M$1V5Ab-ih#hjw|b>{QmRK=R(_A|7DTI~UFIB!WGr*BIQ3$9QSGwz8~&jI zXdCG8EXy9_hDqzX91+JMm{_=s8n_xCJDzwzeeb{UDjB%txKLXhzz=9pdoG0gnU=Xi z2b8GOY6uY`LZJP12S;{yKa|*Z8`^WRL3vfFK=+nPDA^$xlBSNDRC4Txw&e6^-p@r?KbnFp-VL z;bD=f&%xF=2bGhmhs`7*eqHFLn%)++QxIRvm?0p~lQ5A!L)ZHpw;SVPqA_xA&Zg$W zPnx}wU(y2pBPTdp0BHf^8m6?uT$>3*$Mo=>({X9HOY&}>ytA|NT4@NA;NyB%9oe=V zqxHUB0K74K+b30UqFHW(-#a!mxnt%+E7b1yN1q!kMPNpHD!u^*HuxOm(*_Ytx~cNM zKH2~+S9uOk%teSCj75~Cf8i~RTZ$#bYZfT-+xXrg)h3|5u4cFDzLfOMR>=ATv z`{R{OuBs@YHOMFZt@5v5T0h|-z{??UBR2q2xbj$7m?m19#=qZ^Jwpnv7%5jXY4!=T zU&tI-gPz?bI?4$q%FRJf@J@+~a&Jmz{p^LhS5_4^wgg{I4=u5`QD@Ma!+)_zktL$U zc9w zv@nk)EQ_HE7JXel-kox7rQcdSg2d*_iP~-x^2W7?^At@4+-uw7_*#FkEE*T$(3f{e zJM{|}D$h;5OFAz`IycO493cuxfGW`4fKajhlOCyOSxW)SPF}SYvYh)7D=;*CW190v z=ewIH5PrWUq@c26IM|{l{rRrh$e<{Vy8jpRJeZ7n@YI z(mU5~>g|ze>*b#e=)LG4q9}acYveC75P;h}sKY-<)Ts+=_^uN$1#Gr{@m&ITO7v|C zqD~S7h3YMgQjxkj^5dml4ro$5gZt)rH1JDbrNOwS5Z5=*V)9XLnz&-HnpHRxV9ORJ z&s~PU(ygbL;saqSe~j+TVV#tn>3sEEarLUnp{7X1y`;1}{zG_ksPirY>xrPna}d$9 z9yrwxY>uqTrBZ@Qs)M`cG_UJ`jI_Rd(td>o=*l-+rNYKuQ$csCBF78pLp{uz~O86sP_-zoKlA9|<>4p@rBQ!c4#nW41vv;iE0+}7< zR+(>1WbP*^@`i4I&;1g%(zi**smScdnF@CKUb^vK_k@t1Y^&`<{Q4p5P^96YmPx=| 
z7{p{``_RmCVZBSwHC|r#`9Y}04Yw+5I^S)Ab~?weBk5M`{dwh0zESpXSa+IhG;d$W zEsx)J1{B8uH%YvkAeu{)M`0g3-+8o!Cs5mzK}m(QhwqHe zpJN>tC`j@)2k~+sCWO9Tn&Bt7daugU*rzO;7_QRJA{d){6del0A|F;=3{Nzw!Uk|% z)cQ36G$fpc$xl>sJ3B~va_`D-B1AI-8{qpp1FACt;A+MqzqMQ6%&45ElLEGE2Rq94 z{^izs?kP{J4Rl}p5HDtEt&aR9m2IY$#-k>>CEc^;CI9RDA<=$6#X-B!+=79gn`7pE zi+A+fin3@tv8$~%G|k8Ng1*FrVKgDBv|)4*{kXW9~(gIskumlNx8@UYnyk@}(8Yb%FCb`F%vO2qq zNq47|O6VL-brWkJxUoX}SwCqpL(6CMlAx}u4KT(?7Ax?grZ8a03J{L;* z5c|cOyXXiaGm(bddZRV5c?z4f@-gmBwK7&;yRvGF=j1g^)wJVH4@jMUqM<(#H`tL` zOMX)3gzkK66X3IQU<^$2I3ZDLblewBX_-6Y^N@6HOPJfGO=zH?dX3Qs8#q=W=gS3i zDE4?F>Y@bT!{`}u$tH0cx11y{4@5Ab&yk^M6F?_3h`1lJn!Z<6u(%t9%<_zM)8@ve z3VtQF*-lU>8XydusLpR}3_%SM|2U;t0NuhCLW|0iRs3j2SEOXa1zSLsK5AV4aVgA{ zr5GCqJ_My9E|i}6DDnwZ!!49L{_=;tl894_a~S}XMqp_B)E}&BiK6j>L{Sz$X|Qbk zQQA7$TC3-a@}s$u)7e33sXIbsuS$l(6Ut$C!$%GHCRV%j57v3@2D)oUc6y>Zs)n{Vb8OStB@HOR=h*!tn-X9XP;|RCoC6$`%F0Bq{Z@VITX@0 zP&ad)`hn-o?6U-F8)q)OlT8kFd%{|SPPxpj&cN z@;Iqt#fC-q81CSzOuNrKNLIJ`$UDe{i{Ph~2s(HpNq;DZFN_>0;We?Qe|2MxEd*zo z1s>-dy$HzgZGF9vRo3lr?K~I$a8L)>WYsy=MDOtP{}3!SyvjI=4?un?JKWFVTP?b2 zlfwDNZdH)evZc=cW;(Wr1?mx&9KXu z5~us=;&{u|+BD0=S;)Vj zOTnw^nkLaN%~7EcMyuK~o+%WYqEJ8iOZvd+O1Lyj5W|HmaXTra z4wdE83rihJ8ZU>G<(Ooumh^^cG1{d3cfwgshinH{8YgH9oBc5;t*q{dRcm4j^o4jxF+^r-K z>wm`RfkYN&H^%M=l-*%PHKKui3`Dg(hc$uP=bIL*@2&-Es#P5(XZVG ztPP(yHAd!B;y+e!Cp=V;5D|%a#+_=3$_IQjo2dD4m3E9NU@Nxh=U>DdhfrBVDVqZR zj1VMgNhdy{Yk_!3(knYf_Pk-Ja(6vXN=L$`W(_~xuGZxfD?GN2SmWG~>VS#2YL5-R zG(456MEtAW?_1?!Vk`ig=bT+Y&BD*Z-4BS3rosbB9Mz4VmT2%^g`AKETQKn$FEH`& zMz>fVH>@7dItqnOM4~AP*TK&x&Pm71(7F+lkxLs_jxvBb6OrR_XXuJ)6MgjO4weVPRpW_fu6N2k= z$CmO@&7pUH&L)1J^tz|f!Mp1<(BJGvTxjjrD1521WwH@jjIFyC*Ll(yQI%fNPHA{N zOlsD_mde~iVPx$%uYDi&1e(v(FKZ^zY<6|zu1+WbC+O-`D0I>uP@GV{*;vWzvGP=Z zM`GC(YY3UlQtSJ)k_t@nby{N@|xH>zfMYO}N)8d)9HpKO@ z$%}R|t0gRRlKaXlVAFIBP}_i?>&+#$#W=f370@cdBR~C#GZ;+ePA&V{gZ!Y?;In-{ zOlOsmf}ZQf{l+g6G`O}!KU1~sanVp{1R1@dvF4SMzPh?mp&u-5!1IiMK^ewI5PDrH=SAUGiaK+DT z!c>4Y9XDVcHZz%D;_E)#J|1mxI@A2@QyB(HoNI}bY?+yDgh8pa^|BQQoAxoh!bP5_ 
z+?z8p&)`>w6n10dDsEyq)e0vTf)VRHscF=|`BbK&DT(Q>- zL9hDSw~l03b{1t*_i|*M&V$E6c*NCWCN+a@Fey_Hd%?8@nA{x0xUCesLDpz8;LWUe z4!dp?w<)T!p@*o-G*z7NF42|feh;mM*RzP#BUqPEQB4?+Jx3{w*#R|$fa#{k#9)@p z;caTwD%uw%%-Y)U53qeFSPIH*q9M=s8%oBEcSlZ_Qna-Mkh2(bUthGR9#zYtnyQ}K zp$BX{Hq~kaUjs($UT>FHHasvrxQJ9jatxk{+mEu?g+2>BN|*zaFb?Y_$4VxBOL;#R z9%Q1GxT`T<q-7qNOI)W- zHc~*KGg#}&3}wnaXQK~KzI`2>d=ujvpZ+T4O7OuRa3~29sP?@iop^9hQ{1+FRa>CF zpDNkFfs(VMn96uLO4;fOd+W$$uIh>g?pp1>4`_50X@}T&E?g+gY2iP3AP^6^^QH!q z6UqzD@bo-NM`I$|%DHxN`1xh_yo!!Usmq5mgiN`P)^z!Q*ijiwai8~45=_3j<{G@6Tl!tN3ytjzqGTnK-Sy;P(5 z)R(A4W(AZjX}SxLH~i}#oZRyD4OBnqvE0tOl*CRo&jpF=SY^h_ZiiIX9p4;+5*`$8 z{q~9`h;YalH{^&75!+hALS-poeaMtQ09O~AT4v4O)n(&cof(`eo`a4)X+9n(JbJl? zHEs$WtVzm8h%J4*1N%x%9enZ*SqlTD>x$@Jk}HCy_3lxmB0?KY%jyqHW}OfJSl0%*dV z-72F?$L&^l!g`ktZfzE(F4)eRRD3=}M*gs?V6?R!muQ+Y_*z&SJ;|;5xV_O|kz%e{ zDM0kRV=&=`c_Ufq`1ruUjn*?Zu|;-+hfImpb}vPL@(#C)#;?%vt`sw3i5rwqHG=8h^ivi4`|R*Ud=Y_?OUlZ2e1Q#a27Np z{Ak#n7~b5aZlPL;F4(QP5P)=_6K|ttC^yoMxNvo(BxEC^j#-XpZB5`E6&f3)ycVw? ZHT4{9EEM?rA5s3_(X9XXrYN44|33%%4j=#k diff --git a/doc/images/inria-small.png b/doc/images/inria-small.png new file mode 100644 index 0000000000000000000000000000000000000000..2a6961644ebb212dc74f876e872d22453d3fecd8 GIT binary patch literal 7105 zcmV;y8$RTTP)pP+i{AqO0jL*wr$&HF}CfTmbbijc5Y_6AN=&qW#Xg|{&%iwK4VY! 
zx1IR^=7xH&O(BC>G>}C`LfUoTVuwUk>+)(J~ znvHhf>Vr1Y@xM*S*V5VbCtgePUVn1mKzmNot;QNg?sjX#AxurpRHKm!{YtuUBfW6- zz$r8`(9=c6J1l}J^bVslGy{SOAmNIWi3m_=J_ec-88ZroG+Y6pfglnGHJci4_^Jvucg!L>B92D`K!;r(XG98aQa?bD1keoN(A5v z72`G%&;Y@MfHqTeRrfCgYW)xts_F33+g-i)#9yZ)>kqufc;?0F4o9WQb$A5kY&BKlcum00(oVc8KZ( zYNgN|ouyQ6phRRG!^0TA@dy3QKg?1m*B8$1q{B<;D8Ksq?oMev1a*QF+aNiXoE_rVR*(ie^ItH>Uvt4Og512L*)=6&xIQPP>AM~+khl2-JL~1o?MW_=Q8>)C1IKq^IYt~L_pb;@!a}@>( zh-#~q?f3tb)x{$#_g+|8Jh%SPxwZ7%`18MkMxYIMjYh_h3Ib{pxRI@nz@YYF0B(>Q zGpOly-hP*q$&_9GM>;kpn=h28la4_5YW3q3+im8azHywg3+)1?u~SC z_nu>Gi)Yr-<@NOP&5!!DKyzsUrA`D0H87Z15z{7tl%7F^DFWKZprg8=KY++>K2MA9iTKcpiw=@&`{mnus`}E`ORO;zx&Jf*T2(EDiuujN*?1qm4>^A zYqmW{0xDH~2;SC*1Bd{jM~l^f=7H53rrB^*m5@#+LbTQ4@X!7_*GT|@R))0@>XeyC z6;m0(S{I&)wS^;FcaHBZo*FEkSWXw$A9@|SQ9z}kP{ADY070i8^byT|;g@c{{rjvx zd3ZfN_BZK~U=ohuv9w&i;On>E>`g|`99sH&Vy!a3jqB>pk{K4gAZA!FGjn66v}QZZ z*fxeRGcz-k+c2jX2D6u>6|Hb24TFZhcb~p+@;j552j1Ls@44sPH>M_bOjEqg&Oi2m zAt^c#ri|%LN9P;5s7O$&+6#|SyB1|;c=OhAIor5w(30%1#C$EOfGO;+Pjm%~x1Pl! 
zMZ=@KDFJnrcm-t&5`HCxJ@QrY^z%qp=U`23e`T%l`XQ(=NHHS^ZH&R+?%-aBvMV~6 zyy2e0v<;CVj&P+ZMl{B%8=_U_V?+F*llMURmD^?M7p*G&>q9FahRPT`dhPhosmF)8 z2ZuPKL`Rq~_18=^RE#2sLJ<{>qykUkEAV(mQYf9K?tL^%xQ>r<4%FCFBTVe8e*!(+ z*B{+|+Og3!wBJElI`ZRZ<~(lF6<|8rsm05AUDi?E=_Nr20#tDGq3Kk9*NOZLcW#j@ zF}*c4!$i(^u=BmlEN6O}IljPAnBpIt5Xek#RYP%9AGCpTgcl5+!IC5W^rNK3m-M<4 zJ_l*_^S9qeBIpG{)kIwaK?=#katvyBJs;Dnc5sA~tgyuiTZZsv2zQ#Wh6t3MR;nTx ztwOrD(@=9p3B9`Z%R{AoeMnd-%MU+q5EIZU2MJH#2wNW!Nd5U`LydP(g3iG^~ zS*H9VZ))Zc&;pnuWneML{_JCq@g445S1sM?V1v%=1D65)syfykI;Tm9Gq4Du;tMwq z)&xSM(e{mJu*!dQh&w%MQeylD;ZHA1&pj6-yx}Uxk!tf`RY3gwe>B#rzX((ykr}Ml zREfu%w=#vPnPb@GkXq#D6U5{wF(FpzCVzDb=psDv#GsVLGxU62rGT>}231$s)Irt; z`E|?551+r^tL(MuGqZo99qX7xwNHmpfC7{pl5`&TL#NtU#3`9f9SU zplh%GB|5VyH`|q(Y>rR#q#nLq$NoSxL{(>TGu6O}6f^PW3H>%#qHbcM)^}`-sh?;m zeDIdS^E?(8qCtiTN4m;QA|FP03Y$;m_43v>lAYM(l{digU>a*uwd@G-jH+zRG*~v) zS=bL?(WaxnQgK%CG}4!l?I@%IWT=KxUivAhbS!?cB7iWJcwUFhMBGizT9>Hk>a8*L zRJL~y?Ml>50{BE)g|_|md`v@+ly$sT^aK4@a8DW>-qG7^;@T#~?f(4f=Y_4NNT7MJ zseygy`Xbq5$g=Upf%H^IbcUyIhDVs=$zOFAlJO#bRZdSbV5^G73s$8UPlzugxbz`j zMO4f|9zbuA_{-LZSKCu-Ey)uaQ}b)3gP$6rj)cC3^GwoYg-)mDwMXZ=sp+kmsb=cx z^UCkY={)c>!n8bbe(dq|*fz1vF;HjC)SCK-w+!yTrflU@qhRnL8IN@>?<4$i!WAX_ zJ;NR1JMSp)l(P3S)bg5f=&{O|bgZFhLn^xm5nMe+1q`N)Ayr{|@^#GNCMx^}nQ#<; zY>?pppMXl& zhiqVw9-!5NbM|1PFJ=xDo4iq*CD$=I=kXRiru>?f#<2v-Q!Eo`a?!-XWIMGeFu2^Q z?p}|=hc1WA!>fgLe)xW2ZHutExkP3d9bTVTIf${Ya5tSgu+^5?Xs6Z($t4ZO-yeiH z2>1;(6buk`xmjbfi*3w88?(4Y+OQ1U0=R>Wds7LUS^m5|VTUE@Z#!!D(I!t~bd&bj zi#qID@gzz$xalOAecus2HCme1n8TE|AcxzzOR(QGt_WY}xg%B|fl%p+ec| zivBe=xT3zUDh;rO?HEaxm4XZ?e{GJ+npItc4LY5SW zGI#yO`CkX}TP*D6mhg%;^&fYOAj}0Anm~5bw4rmSuAtUB3hV5#g`MOIGp%wZLj>tB z0(B!tj9}n)tfISu1Jf-Lhc_Fr7tqgdy{mu;1Q>K3Qjz`d=e8=0DZ~TY* z@I(6J1ICvjodYoiPsg_KC!oU6Xwb|tumyz^QEpI$gG5^zNxot36ycT#JC#ZsIwX&z z;dze=EKoozk#$B9O5whAO>4Bx!H)Awtz8MHE9CQv?Ll(tL}ZF-@S;2(D5A(kjYg*l2i~DibVk;?m^G&KuPuC^SGrK<523_yS>_*qV`grSR9MmfNKB#~2_#Pto%R#H5P@!Pq6kll zXzeB%C;QUR#t{bRNQ9f~?|{dvD=q+KY?Y 
zs1^3yYKzp{%R?H8c^zzlh-mb_bCa7q^!tA?SW(rXRS0-k3Yq!AQ`GLZ(rK>2Q)}7N zEZmd(@ppxkqDeA}3WE}f)>Bb&gNNScr8YP5TkPx$_sL*@(z?u|P9|xh@$WY?zcLq= zb|vQ624}aVr+AYSI=C+m=+K;Dk_k#d(Dc_nOw?L4&25neccR7}uR%d*&I5PALRcAG zuz@wOEO97aNTXmqVd%JwsKedt<~`|A|sgF$9T97CC}4Zkw(}bIFbu2Ow?P#jrN$uoAPf9;K#=6l=;!~p}AsB#Lq&%&m zfY8hZw(L1p>O7l#wzqWETIg;X)>JB^Xo#`3+`T6i&o!eS#_Y7GPHCZcxhO+Ii3hv{ zVN`PaenT&4WR^H%%dF}Q|KI|9ZeC|$)qK2(CF+_|L{(`h%oSG-SX)C@&p>N1G159v z>qJv&>WoXV*eHSBX9ZCgRhfa>!hRhslsbMOv~D*!d1?PJTb8h(;ZD5DldA9?t8(=X z@ge|viPlt=J6Vb3i5ZO)VpUf3H^+%_8NwVT#-#~Qq@pWN_>c`vjGV#q;!i-OmkG^K z%M)vqFJh`UL3EM?dVu}VNPl|lG8IqF%RWlO8O48oB=PIhV?+zp;7*UT#T@2@#To1H z$Na4!w?j?=;jkL)q=-^T)<#nbb_q&^QD;e8;2~ z(|gZl&5i+=Ghy?m8+@tO&RBzGkg&=IkS9Wh$2D%~@H7AK&*;nVDy-a{A{?+wS>;Sr zd;17?yviR%ELY(~kVJ{L6l!u~yzvbeacAAY?zuaA?L+iae^dYSxqh@y4aJx*zhvKd zOT7Mx_%#niC;lo#)RBaTCyOY7eh3xTL4}O4Z z3eJpjg-^eX)Ow>sTcV>n@*}*V>FvojPtr7g(B+5(ZOQ3o1Cf_t@UV%&sXEOVf{KFv z+NYzV=S|Zx#HcvYJTSB)N(7@rI*tu> zgofMmge^vRqC(_*jo}&8_mz z3rUwL(e5AcI)pBVs-i$Mtc<0Rh;$pwZ;-;Lp1qOUJTbMymftdw+uE4g?vQp{sIxr9 z)5r5?nTx*-$a{Syfkgo8itzP59#@|KBHft2s6o5T&i~FKUExn(ZsRU*qOLdPuD4Ua zA0zB{7cVj8_H?liuHi1OFYK{U7kj9S8kLJ16K8pp(^~WAw#cV63p;Jnk(W_E((9b@ zjU&jl-u(K2yxbL^(?-uXrRQ1t=Gs%!ZTSfuxlZ&JYT}&z(SQRnVI(j&Jj!!;w3Qs; z?5VM4s@o4$w5N#~hbp~MBG69+BSQj_8c*NI&ch>Ta_{_`N0};&ETZ2Zd_jH`i@ ze&#ULvg+3*@Y3i4a5T^|6a_8gdkPR>!*UMqxMMFTF|p12knl@X29Kj2NO+4xpl?|F zp-N|BWUzm9dyJ?Ro_?ex$?2~?3t8HdPH)uX?Jcvy4)=e?b>nVRaWC_``{He`sN7+6-6>$dGG|`Qj4;$HNU4xy`V`w+mbuarJw6y zw^teedz=Ny0AXNyQ_>~}y|GbT@6gw`<~KIc3nq{Y+OqQ-OF;{N@F?PH4h8m;o^+e9 z$JUUw`qFg{q?qEPeaA+*yNA1wfk{_5!b2LO)gJ1RXG=#0j5x=r z7$+K;-M`O$@IC@TMvo^WIbzf?)41>`f<5y@Bc3&M2DU4JyJ!Fi@Av2mBO!k0sr-Ex zv#SYuU0Zx7QQTRPy0%OE`#ntTGi*E*>{5$@g0Lw;r&6&P14qR`(($DH)&26-wZ^q0 zlb6@amk(zjSs?!J9-|Nj1*X9Zl!R>d!R6#7_0sR^)4w0d|87+2vHj|iPi0X;z$+RB z!2_u+Wj}l_wav~gt?XY>7hPDBS!gcYv|IS_LnIT>A?gCJX5o^LzsRBaHnk4E)s6di2w|<|CKOCN7xz|;qGuxAN!J&J-GM(M#o8c z#S!zQrr&c+h*eUn7HLj9p<_hYKA}#P|6Tw7kCwt#_B)SO z(2iY6$KSSssH)99>ppJm>o;S4nq2DDRXj4*EPbn?*7?)V+-^@c3v$IX>~h@*Yev~~ 
zuKwfq-}OhAD?JMLOQ(mn|5kf>QS1Kr)9?PU?;AHvlap_Qq zB@cvNNnyPF#?bzhsXXen4re8AQ}xO*W9?eCO>&F8ctjK}WG$~xTBL66-YH8M^*dS4 zK0dwX64RpExjahXK|ii&M=0xZ?B`!|gW7Le6HBN7dmr0=n$7acGtsq8z4@cwdz-90 zspZ@|UtN}KJahU3)>`MB&vF&Xb3*jYe9l*m5T+F>)oFF*$TD3OwCQ&|%MO=m15=$W zGw7G@P|+`AsA%o-NxxQ(((O{4_FR}_q*Bs5&RU*tDPyrk^_q4rCx+`4*I`=mDx2py zU8*YbIvJZexa8K1=laKGQ_bQi>3tm2?tR`Yl!w*b`X5aNr9CWBDLW#p6xHUgDsq`^ zf3loY$93k-RS)fOPQJPl$0FztFeK$R$G}t@pW~}W)O4}3N~E0~!*m)0&+T4F7_=?S zUaq5c7n4ksoT)bs(~A@mrjIC|tt7c?uw!r~j;cu(?5Z$n-DV~+-Ulx+V@n6NsUjJ+ z9V00c)vcb)5cEn{={8b|7^L@`i$*cYW$6tz(r{2JEjlO|YlJk+Y1w;9&5*iuce++e zA8=S}C54{B=$!vds^-t_*|7lt5D>t?-+yNSfJy~&E2K;y=1}A6G~YWtRjFkRsg5`A r08)WeAQh?BfmAvpl^Bo;pQeHZP>p->{lr$400000NkvXXu0mjfgw6j# literal 0 HcmV?d00001 diff --git a/doc/images/intel-small.png b/doc/images/intel-small.png new file mode 100644 index 0000000000000000000000000000000000000000..42f63535855fd4b7cf20c9a2036fa1abaef352c3 GIT binary patch literal 10935 zcmV;oDoE9dP) zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3;smK-^Dh5us}S^~UT4(vJFftIiD0$FrTa%ObI z>SiS~BO?&cTmvA?{-6Io=D+w;eTychQgh4M@+Y>~eCI{A@6Y?~Y`lNpKOuiUb3eat zeEfXic`NA1_s_gP$9LXO-@i~k$Gh?K`=-p#dFtmxpAWuH7<6UNi}e}g=S0E3_uc2; zrup}x=O1q8^Zz>U>%P5@e*JeW1Y;%67rgO1xM1(!{#_|BiqJyW`~1y2UhjBb`%v!R z^`G%G-$M7B0QUa-_}Te6B76_WJJHu|0guE%T$SN`18*Dc;x95^^txL?cs6#h@VF88PLQ*F1CtSxSJ7(4h!bDqc^ z7u|B*9k=iA=_W&re*402@23y3UqO za^>YX(mPJ3EdKOketqNr`t#F?&QX~=r(pJeW5u|<>Wg70bNZX3NXRD*Y7>3p%==}3 z{iDPNGMH|dD>vBfc&8Z2eXp%>^Bin492)KG7gyHj1cZodCkEpJ8L-)fJQfg_X*HB|kHP=#WZM8Sw0x*_L%dNE9TI-#Q z_SiXB=lISW!;diHNF$Fj>S&`+%4f!zW}ao%*=ApUg#`qdl~-AHwbi%Vl+unn?Yzsb zyX}6cwG&P}>Eu&RJ?->wS^H%5=dAsjxxZ)4eX{0nqI6vOo;BVsD)_ssn_s_ijCTpu-xD7W!>Ox`)LI&E819t7ccJb05f2SC^K?31mR2U$vJ)(Du zM0zvoT01fFy~ZlLBzIP(rs6K)RAL+B?i_spu{wA6+~wKMwUt2HvcfrFUOA?&l2Q-5 z0+^wv>023@L67RjAht2(a7x{Qo63IlUbFRmrB?ii#G$t~2N^ae<`zq9s)Cp+0qMII zu4&hb{4OIA6_I0H`aL689%qr>QqE-13fgIg0op!F+AX=sHT&5_|15igA31}@`Q+Hw z?#1dv$JV?%HxyQ1T!?XUi$0saGl6zxERxK)Q}byQPi^!;8s}=$dcHJzR_<$_o1m7q 
zIDKq&%~ioD>ZtB@QlCdxqd*6@!>e}GST*bQc4XSK;@)2Jh)AT!XZw28Efq$3C$Dvng|g4`*<(MY0&% z^l^3lKKFz>hC5!{)91wH+lkHR%(1LY@0`%jSoU*Ps_=*NzOixUJNA3-&zX5=8|U-l za601V#JhWfN?qO&zX9`)R|K=3z4P~25Qzpq5eIWat58siiDoQ4jy64+VF;_dD;AyA zQy5?4?LUrpq`9*hr(3#~QkpGHK+Mzez#1)7fgRpll>v3}dKxSA-FM3iM5DK$o?^#v zwPZ6-qD;sl0|I0`01+CUo_3sthD*h+u(_FbS`W3nG6~$=y_dZdOj^$H3G7!>ar7{Q zoWq7@={_(q0wm2Amo~#si6$AgIpAi4LBLntkqAW+z6~iS8sgL=9D|*z__9YgwL9W-uudQ_c*M20LljNH6+Sf?M8!YcWY+|Mc7F zMIO15-omg-qs*Ek_l@2?8f^iJisO2V^ixtf(JC2~8DsXk5{CrW%(RmN1J(Ras%uBR z?KyfOHt~evpapu3;!1vuO#z%3T3Ig4um~%32jUNp7!9@A;X+N>)$M8p>WKt>(dQ=3 zT-lMIJ%MX?B1M*f7igh@ZeypiH*q--WH>X`$?*2t#BE1k7mSmr9_j-atP-b$KXJJ* z{2dhto@?K>kQ@i%QP~v#z3~GePKvmIyMtLHrjGS{9 zjr*suGjtH++}5CQ={-oW-#(r`2Zjf@x*?{rB=|uc0hg{&=4q4_Qm&Cw_JE)WBk*CERJp4{bNaTUG%Z;| z5-M_3k!T_AWg;DVASf=B7hghH&50BPh(bkY7UBg&me$)MUFgA1^>oUg4r-+ewD4yQ zp+%F9c9mULt_y*Cftp(qdg> z%rrxqB4~)@67*wS)a=SqqXp@pse_sEUT9H;i4O8 zKaN=BCPlASp?{7!waC^wSevpqm)CV|%op@dT|-xV7xtx)7^{jqhkc3O&Gb!Nn5Qmg zQw5$m>Xmszblp%^5G$GsdA&J$qfKLjIUq!k*rNhkJXn>PsqE~DM{Yu;+{o%d_(QK~ z6Ox-ES)Rzd;t<}{u_B_K89}2!aYqf~_#% z0)0D!3ZTWPDV&Fd&o9-3X2t%tBdN{-+6=XFIgCOn3?ZlO_CmSB*(}r!! 
zw@n-7kCv)W-jJ92ZrDW0;8&*54@P8SL%@SuPu|xeVk-w0hg&f*A}(oLdNyS27)Gj# z0nKU{mJsGZHPcMg1YHk1VY`H}_oDx74ek@7IqfOWkuWA@{TcDr-YIM>78YVVo^-ND;PmLlG0nq4RK#uI-SXM~dtQMHC3Cc&v;X2iZ z1afSh()lX2U;(c|wIXY_I8?HQ{1@ab0Jh{>;1s^ME_{uU&+Ils{h+2TgOnkCCt0z; zv?Zv3zpT(s)GWm;aqM7w>=-o)GF;mXbnXT6q%P3$WDb>vzNgm25?n78BeaG%fnj_O zPB%&PUbZ0s1Tj;@G<>-rjKkbn@K;w%--e1z>${J6n z0{uY|2b(v<%HLHrhhoUpQ7DiG*h4lW<8y=snTal&$UII7$07I(kbQ(rxYOwkF*+gatGdnk&dZ(w_(|%mN)* zVbQR{h(Z|~={0<)z37LfrEHM41it+7a~vd}k^h98#bG!oiSlQ(D?*_36&8NqQQbmT z8DoO1$~c zi7E^*dq4V?2KA+^<5T5uJq4sRq+sU>LUk30)%dNVN`!WD27 zOED$_F%jz7MZ6FP!-T>RE?+MSnK_vJLdHn1ik*;5^Gd4VUMHsWOc)Dz(Pto1e#ESn z>38M<#=TJ*u1k3J(-01FY|%i6j*@BqcuoC)W)83hzQrAN>GhpZ`a%q$qWNB zsH=sbnNLWtG9s$w!1tgPXmZ59x+3%oVBsH-Rg9(#)dBc4^nJ_&UF~GC#qUksL;7VN z6TZQt-wQXG&XAMuni&*`%}(PGB)%@~<|s+4c`hHRQ2L5tVNZ-#E9NQ3D$Ek>jJktG-w41y2R5UDzgfJ1x(C5SbLZoO&W z|4jRqbDJIGo{SZWz$1kWAS?1Uv8X?;ERVp;V$tD1vkT5Oxn>J-oFv3Xr|6!ky<&19 z8TlzBB`MnLNaXpcJpf|fB0lZ{5H-+UkRMOnFa`!LY=g9VBmE@@1-oi3rUHILvnxb( zYd#*5GyxP>3V z4L;EXU$bN00!`Cc@oF4`+kfi|U&n91;uEL!%vH7l2Y^Gu017cUkZa1-Wd;bG+>-_x zgKvX{C)1GJQ5C&qF7uv$thfe`;}TGml&Mr$?|Glc;wTY!zFI(vX(vmkx|C3 zp_#8`@4@sk)3fYosC+FuO{U4catJA+8*ir>C$A|Wnj2bOHHvJJW||=_vtnxu3MfGD z;H#JRM5SszaN11bJ1`*qEi~GvA`dhva0@{~qRZGp=qb8H_a~Ll)Y>vCU`8QX(If=! zCrhD_09zjx_U$g1D9gYxQdJW`oExk&xGoS7b%g93L5A#$<19D^d<<-nCK&KrcI~L? 
z9B6lw7&6)DbgVf;G3akROQA#eteVbV4g@$i9&WEYWi|mS>TG$MDWX6^ad*lDc*8Kx zi%3RZj*nI9q!B1X$xBaW#0^xO?qp}-W|n5}aF=;Q*<{iCDB74UNDQO14+*#v-}@EL zU&}_zE;Ttg55~SMA^1bBIlf)404XkD%NIiZ2Wh=q zl=}H6wP{gD#D>U_u|)+1%<=St60o9k6b_mpJZ*~sts}uc^ai)(#!$v7#k?=}5gRI; z0MZMGTO&I+z_2u)iqZl&kQ*T9WZ-}~79PVdy$8b4R%dY6r0q_Vx)QM;V1PO0b~(sV zmQ1O?bGqt4Y#ZO005znvMzM&jVy9i1CeGXr#tZY0YvckR#fBx}D8`D2wWLSRm_ckl znQze|&w!~z12^3YdEb^pBo7>!K$Cix(b;356let~2TmrWhydumVDLP)aDxPxMtf%f z;J0+{WRG%8k06LG0ZGYdYGgN@<8$%=$`h>d)5yDb*C}XBw5N{3tY4oXs~LXzh*+Z; z(D9{W$A5;~KjOUkw}9UL4D{f9`4#64p!^5>q~(@|_@M`rz;SiCOHNxi0Y#Jg3(sVh zqX9O_)0dqrB?EbH_n8GL*Off3M|pONfl>+BnKvsjKHlV`As1DG-e~ZqgCfi3%t^~w z+xn9Lx+XJuKIHu?W|PMInfcv^i2TcF01{NsoinJlv5-A;>%T{FBN%rZ9n}8@^g06E z`!Xz*;BR0#1u*ah!T{iE)Gre{`j6CbbLU?Z%aMe#LwJetKP4tM=*UTLd>Np0i#640X)5Fet$`}7otUJLw@+A0j&0_$zd6$Iq@OIInr!^>;O^~{B_W>MZmJU z2nh2|>fTdWZ`v(?y%_WYwl5?R9%d>Xge*FHViyB?+@=b!t7sug8Xe@5$Sj(){nIR5 z2K!pKxbtiVPS7c!DA}5{9ssuh{M=~K#18NiVz3A4?CSy6H1Q!4Cx*28q{FG(zNg3z4s!$tafR00nGuAA&of_=cJe*hV zwTDrFjW418F&EmVQM6334JtvH5(gC;yp}|Z=8|zxOQ=p9gn_^Sh%$v2F*waF(au`Z zCJv1dk7@GF8Hq)%XlvTxQXn%b!}oyZ7$OA<-fmRmh;D3dQ!{&ztfurS6dEIRGeZc- zE+-KheWdmy1~lL+G}VWW^cW5}ax}h!FKUb$LtUT+M{{=SK+77-@Rd+wj1E;9t>*f= z=K3^M*N-sq8y=#vzTx5BXSrnFK$WezDuaoJhDBfn(kz4zsiBQ56tGJohLrz^>zet< z?=t4vuxG+#!AAL)U-{)5TBaZ!Cb$EtYeoD{W?Ao*GSM=w4$F+}e5)Kkg5T=@;6U>@ z4a3eE4?PITVCPXLjRC0HAX#bw$Mq-vc=TZ?U@8Y%&(ig2zZtCh!E4ndJ|djDkHUp# z9dB;SLRQe{)Yc#@bv6(1tGwI20ZW`<`}^VVK3f#XjT_~^EO*YJu}^oAf9>PlLAY{2^92$Hkl-=y_;O_Io;;c5B*LRR+E z>!Q(_ixfH*ivPv897*4r`HBF5;(O$PTsg&ghh6_`>6JtgXaKp?UL&>&>R=p!6a1R5JHUu{rtN5z@h?IvOIcPj39?izjiT#> zKqx0lX@Qg($lVuEB;JnZqJA(Noq|*!JME8q+o$rq+R3yKJf#WDdrSH1LwG6Ts_BVz z1!SzgSv)AzhVSpvA!vF#5H)V-yj+j>%m)W3`L~vJFNR?8d8Sq6xcI?dUEz zy>iR0$iLBXk&|i@5`)a@+qSt9U0}4uzDDb+c4JNHAyo0dSX(auamH%5seODSL4eyR zFEhf2u|Y{I*;Vm++BqB%5jeY1>8GlB!X30MEBQPlk_%kkkN_r$0YyVRXxNY?8X$gb z-;ffGnA&cMYc6*-NL!2WAP60gHr`rhpo~dhe)DrUNZB5Y+oMFZX>y<6Twh$!htqdhuq~JMd8Bb zzN2wi1uQj2hU`ToE-D8E)KD3y`Rv&{=dI%3amGoDB2{UNS^QckL3u2Q#CfYp8SJ5e 
zEuxNrL>#hBrZY%*3@ai>R4dYfa0fkukUSk_4|jO?`cEWiXF zX}acSOYItwU1T_FC$l4X3wfS~r3PU=)SPBCrsTx0dq}^wdTOgp(MFsSmKKe$GPFCs zW>J&m*R7@9zOyJfR>hCJSRyDv1lq;`pDrTKp%A#5AW3{`8uX7LuQFDsYa}N`Gqovz zE0}Pn08x3;=yxki#lXQ5npE-d7s|DeWZI-^ag;VFxEDR-A-Nx;}-#4i8In~W7kGer@2Kvm_xm}y=8^r{{eh;veIPmupJ#^ z*$ZYZBY9h*B9Muzp5RQyHI_CgiI1NUh&(mCe~y?1%@QRDslpw}kCoy4RoDzMog z$q4whu8Az|Ox9BqJqJC3r0vMaBO5blO}|bZ~(27AQBWR{f6H=&q~$ z9pW_e1kcCT4F(N9a@y3NEJ6iR2&y+ZgH;(&#s>u1WG;v)#ZM+q1LtZFJSXk4qiD$C zfa0S$zW}>z&G{`^elVeH9hIxr6uWlZIgm~pz5zvTkf)KP^)PLYB%7jrI3TqUu-w`u zreaf?Oyh@FI!upU&dVpIB&o{4)TnJpbd2Ow>WI-Mt*-fcv8UEmQ94w)!>9*MW&y~^ zw4?oVn#+=jU3!wI@HL28&&H%Q%vhQk4uoF2os0b}B4peM3QLQBP!PtYEB0Rg4fhti zw&iMr2Q94i!)49e+UAqJFQWa+@6?^PK~tof9`B9qmZNrXooI53b*2bF5sr`~;$OoNlxSgSH;PFe<&&qQWJ4h~z-XJmA6|Yp0E{ z=w$8%So`^T0Xl~X8#%z(cT2j(mOJKF$MU|rwVjXH;2p?^Rw7B*OAa7)^uSj>JHR-4wjrcb7 zNOXWDKna}PeLe{jPIKad+@c-4<^X!8#KYa=3Absp9yM?79?yIi%%h$_6Soy5avfR< z(T|WEprcP(x`orUN9H$X+5E7HnIJ6^MZ`5VYTl z1!Am&Lf9S{Ro0I-ALOorL@eqV7_7YL1f{E|T2^jSo8F)naE|RFs`?UO<_7WFs3{7z zg)$_5UDWgk0NBe# zOi;yEKp6*zu}a>I2=L+5+Rmt021yw;NPN%xxc)tba^XKEKuOdVB;ss7flk^%eXoDd zONnum)dn|hu=jb^6#|tluBLmJ=5Ey>Fkus8YD1im=n-+%_EbHYAc;T_1LSY<8w3J| ztVIZ{iBkZfYbvBtxs#&P;;^}@1S`FRwt~w^YoXGin6-J@FUd%uC5<@Mmz*PLV|>=8 zwxit@+EGS~ww2IMkj;JAMoSw+>OMN)Ebd&hjODIc&1l=IE}^i|b*E~wK+C&;_t3Nt z1r*2kW%2N}ENHtGlm@RK}$ZKlwc5c ze(k3ePr9vbi+4${=_b*Hc-y@-Jtf69bz)?N{Z~|}IVli2g!^<3&Aw)unia(_U+L!w zpp~oj?>QA`L2F@^I%(b+=<`SmSQ&cM1qc;aQAAbs*n~>PYDOdu2B|69_tXWJR66)J zjvt>#E)r6>XtM3v>#k>Mue}C4{O;ChbG3G8X$2H>A=hg6Z&%r1nmpCtG|1Z*TpaqA zvw-5&XgjA?mfxNVWYux|FG@l*HuT^JLP#UI%kOxuXuPE;V!aP7(#erx*3wbB624k( zD#_Y!3$WQP!L^q5Vru_INjW%em8~3S>hRM}N@_OQTGhx^|8FnUlxDrpV1)}`^Y$aA zK27ahnOX(!PK&8jHqHcA=2GB7*ovWa+7hl2HQcD`u_DAI)R3*19iMQ3^bkEw!;;#F zJVxZ!Z)QkO%AN0K=($haTvFOPfz9n+H9Kw2zsYw#B6|W%3!Su`-{vF>@YK6GDMURq zR5kIkCx~F%dzNLA^Z|%$J?8jHLd`Ck-#%!iNpE1p{4ZQA6Q%U+_7wmC0fcEoLr_UW zLm+T+Z)Rz1WdHzpoPCiyNW)MRhX1CDQYsD>6miHC*an{wrRS*OpATF*>iY`*(xk;f#j28~~;eVn9G7o@It<6K@bt zZ`uauec}i!$|~_W@u*1+58V^%aZ;z{DDs_B$3 
zWIR?mZ*kVjRo3j2zc8HBSC+X>HG~8fu>=Vs6jV_{88%|H>!esn(Rtj*Kj``;aw+60 zfstbY6=;xMKlmT~o~@Oi826IGaiH_XaXyBDu3ex}bDZyE$7!4Z!Drw~Z}}^AVCIwb zT1$%_0lnM6#dS+l_JGSBVBpD+P1%)#R6;%vyr0oGWr4n1AhhPqt$B{q2Ov$eO5Okm zhrmdIve$jy-Q7O7fA2Ks_XAuca-O;a=IsCg010qNS#tmYA=Ll?A=Lo{oV1+)000Mc zNliru;|c-_8y?~lV2J<#2MkF>K~z}7-I#xD6lEI6Kkv+RyOz7wLIFKMB@hrrFnUs9 zw(W5m!NZ8bYE+CrA{ww_XOQ#sqTHcy(K~Ur)KJcFr$!6}L@3lKqN3d$5R?iL3K&2* zL^%%xJdyUgyYKz6Z`V!uFLjc_aay2j)$K2BO zz|0c3=1u|Jh!E3&aXq78#e*@0x_|S%91Giw{q6Vb{_%C z)oMR5636_Zlfc~mz%U}bPq`^MSIBms2Od|ec|^aTnFN@G($FuPB2ii4PiI`|2R?>u z;~}6Ia2&_1pmd`^#I6N-f`4zTrOyaInzk4C6;O|q{j?5+7QZXd^do5c4`}*(XXP}6 z1~3Eo3otXzi;hH41l-D!5-2OE^O$2MLQHH8Z0zsB3Xk#x&4?#=q+=$Vv0aJf9%ouu z?-H<5Zb!-8cc{0K;1X>(T3CpbD4+!1A$@ISb|~)tYG$FeMFz^awmt0l%9BzwYZj z)hyr_fTZq!M_B)&WY?O`$(~=-fbG^PV8ZZ(@=R4NY$0Mi2S7=s;Aq}{;5=X>FcE+q zZY4JPZATkEK+5R^1IH4cosZA>mSZDgytZc`a0<)qh2TY?7{~&^@7u%pCuXo~P3Mq3 z8uB&C06ab>UH}FGQz_4nhw=s?fIOf#aDEFVp>G@41;Av4c%z;2t+%>_0#yUt<5dV4 zj9?{}JF!#1G;J`@7{T|%fm!Zw;9=mOIOg7FyUk1{?i7-IkYj)wfv*&Q9G2TRVefxI z2I3cW2%ZNnQbFlyq`XR5#D4%C!2)GPsrD(8JgXW*F=wgMThL;&5lL2^#0Z+3J-)I+ zg=7sukA%GuPKi%gXhyNBHP9z3eie=x_i3Sjj#d3P;IHwVB=SL4Ck17>eswU7QS(g$ zJDdmH3hcr$x3zpD&lBytGB@d01O6tnw%unvPltH!N2wkx*=fxJE>x>Zf`NY-AjYfH zTF*@mRN0;rtM{}(T}s~kS){JED5pay_br3me&amlR0JuXLglA`V%D|*1yFB*uUkq& zuh8WQ$<<3cI+T*z?*`RyHh%rN1OwX}fnBbct$QecgE+lzKEYs&SlnNuu^MzOsub#j zV=ig#^F9PUqbGGIzUS{Z!IV&INz+3DLwgE%S~a!+Zz1KgUX#dZSCyBFqD_@U`gEkp zmol_JV7qrAHGri}rf9nrSnepS@HJ#)&PLM?k(GNL)7QhI0$Mq`p}103dJjMDM?rHb@F1|a#XG^kK{dHrtmHM) zNnjz!0Ef1z|BU_#u! 
zpH-cK?GD0r2Qh6!$6qi%tfbauXs6t?<%3(uO`~>d+K^!yX zNm=V&EAMi+dto+0o2_~rUlJJ+Pf>eHCGsqqi%da!1?P5P1;~Hk1UfxP5=lnd13=6G z?getYIb77mErC@?`8L78Zie0tyO*U|aVMjabzt~2(L^DLV&w~i6Y9)#Msi;RpQ&1O zBgic{fi9agr#I-c!np`BOm(|MfNY?LN*CG9GY$@`UkQAn?spJlA7$o`L&^O(Pbb}x)E3qs~3P`6Ci_+b_q;$7SsDN~>pola`mvkfD(%rHkuz>Kc z&+|U-_x*8S*EwhA%L$tq^c?kFEKDMana*40zCBS$5%rz zdV=ZsQdth8{0Gw>dV*sot1gRyQT2}a*5a=wp___)C%YyG1A}Q?RY6wU z*YqfdxYAnLppT32;UhwR%H@ZN%D$YBFt?a-pIgF#>*2WLd`TV_te#DFlmtXgxZ062 zPdJNkUw-C{qUH`S#7OIY-xToWX4nP2MnOU*&@wH9#rBXw&UjRGZ+UU@>|}CjN7&Bx zTJLv>#N?mfKK}mK7daAwUq;OvE!LAdfYKB&Si|l7xOG7HD{Tj{TJcwnA0u`H)meUi z+QZkS-|seH^m%p3D;F@(ee^0Q9W~tYOUx6O#hfCP$EEydiGkH-}Z1}zyB@I31hAi zeiH9@wG>?|z2JQ{DHymaA!g&1IpgHVz}nci7@WnHR@V5eIArF{ZHVz-^8{WLGY%3_ z3?8ReUKK3kIbVM|kzL=XMQt!AjG0EaOaO)hjc>J=Jeau7S2M?a^?n*R8knz$2#?zs z+$HV(ac&u+k4b;R@1#%<6TX)6>MWgg99{$nZh61YX>Y4I?b*ArZL|gCx`5q(oA^-K z7G$i3S1+oV*ySHU+_!bZ5)29%?FU`5F4dtw)p#A>FgQSSo^hoI?wx<>ZG$c6$>G|a zNSiJ-$5ks8Y2177-WZOQztGgKXPP6WQKHTH#ZoXbKGqn4RK^VeduHH;F}&@*|7u9$ zY-K?B0XU;1i!TFuwkxY6FqHQZazQoaFsas}rl=bMm#{EBlG{~RCW4wNAtQ?GEQTY_ z>A&T1B>NO+3iq(Qw|5ji#?ZST9E37=lf*jGwdbGV_VjU-K8Gr^zZubE80}V^Bm!NQ zB=H%OjvyhI2CPr>>HdTehwZEH@u^dO1L4@UfUHpg+5^1h(YB2y$_@L*TMG6$2JZ%q z)2jqeAEUNVa>R3$?Y#8|1bQYaes37Gqrn^7mKGnQN~#Q8X35K&-Xk9h7Y*aIUT1e1 zSgbxbPhIqiEcv0M@nM?W*to(4w&rM+T5wI@f{9zM@Gr!$2W;3<^#PY%$?iHm>pgPZ zgHmw$h4wFQs|N}U@ziwjz1O2#AtM}491brK!I>ght*UR5oSMyVdjG6e$vfk|Mo2Y6 zp)m_WiTZ-+k{BV>#SKD;AN6c$ntIBIm|gOhVL4MCqjtV*h?YWPY_Di%#@u^)Nzkb4 zOLD%hZW~X0GY)Sv!y~dFvf^ZH9rM1m=F#)A(>wR4l9}l{F***?0{CIhkxI#4kxvPR zW-gTk@Yf^{j_ztZm~bHJOB;PUPIKfJvwwsGdPg(vtnu@}f4<6AnP#%!1T8Z@r5e@# zJnc>ze%H%C8Y1%NtAN$W)zgNn+@|TaBSC76=9%(cGm$U6`otowlzOPED z?dnJ&L125R)&ep|o#_nF1csvdKjFzqCr5+;>S!%z`6mForfM^nraHG<<2Qx{J({FHd4U2zU=H+3;!6|-6k0a;;@Gft+^tBg=SWO;q zRa?|%j#O)rPocSD@I3vL(F-W|=qwdH7^xPIkfHb(Vct93ox$4C*pKAe=*9g{R*u(E zniuD;OBWCLBC%n-6=`5MlVwxogw-iZV9Tt!qxi#Ix2eC^PVE0=&MkT=%TWf((W?wN 
z*Pt47fU$nAnP{ip`IjDO*74AzmH3^3ZN&aCWCC^3Mz0;sgO0UyFYi_pyO{#j~Vi5=lXndED63TD^Pf34?f6*YpUHx=Y)(9?gQ7xMC#_l?5HT+57aIIB>uEs z%wjvSl&csh>-80#=-{rN9gF**a>*T2k(;0b$hT8soswysXSQiPnEvL4#><85l=gw| zUtlMjg_DJ@ok8_HU}V$sPhU=IjA*3LJv1z>=yY_ok(*sBU- zb-i_u`$(f1$zlMvb7tB?ij@O(9jpBUlb(@ZF1L8TNRtD<)KMP_4ysWGTXJEqmCRNS zn3$c)&O~p~mKYhG2SE_Qi5EVVw2IRM?LD}k1sDYLV4TB$*kKdaMan4y53wm-qZa#YGhd@#K`1v>LAg#}5V|Zgbu2Mw%aXxtF*QLH zGNUc2V4zi-M;gdApzw2x*#8;5CF9%|)~qI(DxEDPtF|tQF;_8JWLos# zvipqpE1``8FZ08x)j-}5dlQ6ynmq*3z};jY070}Ea&?5!1R&=aWGzad30v>i$s)NPT(l8~ z6M-!1Yhf{8iK$9{hb`EX1K*a4(`ObQTgx0dx3V(iE9q3$^9%fc9AIss@gMim0R4}+ zMQi37Q3UC6um6vt!%*dASIuv71}9e``&b$GexM^4Sf)|sKVQ>tLGg+UuHrVx3Ugx} zG-Zmg^}4`sv9UOuCGu)g|B^5{lU^_80aLj(ki_{V|uDxqquAc zdY3Ai<|k={c1DD1;h4Ub|HgS8*pnQ)uwc9HIE`S)@GWl=CZ^5W*NwnAxlh$68?$~y z=|n(V5tZgCTr*4}D$ezXSG&mg1N%}KE!O^4ikX#gB*K|oBYyN=&k;5 z&^S@o=@Z}80MJoFMo|a4F^j-%9%dwTpY4PRgd9FdU|m=_tb=Q@+bS z_KEVCb-&!6oFb=qHC7*OH7wp#mmZfG`0P-xAx67UIC94f2YWqjYgCW^dT2hxPx7!? zQ)<2(QYvD3px&6H; ze-?kC=RpFlhach98wqz>H7V-1?~Y!|2FT^(fBeiBLO~~FS{D)2SG>y-A+7=OlTqXd zth}vXqszRYnKGF0sdMinWzd{7m902FdN6TGMmb{st_;Rcw`G>MN0pqXEer<;^wBlC zoh$U?yyAq{RptdwM#=L#QY^h!CFBYdsZv4^B9`2*mbtmBXb`Fh81Bu!M-?rJ5UCXH8{Y|A2=JgS+t)mO)fr0P48*q45}h;@u!-kAoNhEEBdvBAE*cb57-Ddc1N812`W!}B*wEB?jtB~?KflZp}7pN@!&To(W`eyV_N zevSi0QKj?&7%l03aveXS%^?e8Yg$RG(Si(oWaC^ zCq4oQroZ5|pW~6EQTVj!Rdpb!bIVZD2^d!~<&d55~*k_V=Qu!@fqe zz`9}^#1+Kk*X?7llmhxF))7KY%aT`Ru$vRm>BLAGt)Ye!8&^w0RRP2LkYbj62ZwcNd%SyD1TCA9(0|obq=MhZwvTyt6fwhP+!w^%P(kU* z!^o2&AhyJtp+#+VI`05-Gm48*2W4TkB0wc0YS4F#Glv}lB4nt?-{MhyLdix2k|tVA{zv-Z1@V1X zb%bx*N5{@8&%cSBL9Lw)FTb7WRBGFvHR+<#l%?tpBj;)lc6{Ju(2CJ z3gNh3dG;S7BVxl|9-aAfz^AmH9j~XqanJevEi7odVR74{^$LCCsfo;iAZIJ1=dU9u zprA5hsu!THyy!yNzA-dw75)^4r!tK-XqxkB@OC7z1taqIHjI0aEZL{X$Yo5RhJY%D z#Bfcks~BO7E{Unm{_JC-rBL3dKSNq#I}suaYj&DiTIpvMPuJ{AUGso~6=kHcvUnk` z6vE|1gP+lLXrPd*v%G|r#L6UdOtlc$`vLyq?XgXR=NY*oo zLF`|OBG1SPs17FZdZx(H^m*esVbFL7A)snb0!T}00yma(F`b|*X><#Qjw)UZ`%-BWU*(oG6+Ha9{zYuk~yq 
zqtc}=FoBUpA@4p_kaDp|l?VbZ1DGKZBJA(Zw8)#c&eRti<3QIhyiA-01?k<*M!Kz) z3_qm${FN?siayS^=Q!auCd_BxOL38c9>T{#f=zH4L+)CtTwMs8ApmBFh{!r_K(=v1 z#yM>ABs0D7Sns6{L!XS^D=BjAcXlVZ#hZ8&?PEZ+NI|AJeS+OTWOHP4sT)U@#3Oq7 zXQD#I*C7ITk8LVrq0t7#rYO$Ul+g`(gx`C!H)#51`G{D`=1d zqp+{clfPA;COCHMGUc*HV^_|+6%-bCB4YX={66mFE4j*WrQI2>Tf43zpe>)pft_e2 z&FuP^WxPw&S2xUAG90_j{@bnmNcOArBC?_H#Ce|po9ueZsvHz>h2}9Iqf6(ziO?ZE z+O~l9&5rdiArj$**6ZJTPM7t52CKUt+L0S!KGJw%O6~HY*513 zb;l?8fSHIh(OlzO>edH$>>L}LDD@uwHr)hg_HpI6Os`t3?;zRj1Gr*-@&1++aGEug z>ndToG0~om;V;v9cx3cl=lRh=df@GzrL6x@C?BehUP43z)O&HeUZs|uh3CVUcB6el z;Y4Pl77DNEfKL>j=loL36-0mCeezF|Ih+ri3RMS%SOB?uZsUm{or!P3>DsJt>DFmK2+j@i@8E z@6Z$uilGq{3Om8eYDhTr zydQ8h$~$2~-^QG0*8n+aRI}p3QA(_}Ye>?`Mal0L?cGO{iwgEql;8LTK)4QCD!fS| zw74MG^xUVaP(~5I5ch-hjEfzUyU@LQyAtyYE^QDRZCd#16w?IqYZo<~` z#@tuO^UGvG8|r5(N@%(XyWw_uxIGh53y&@&R>Bex$W}K9(gKDE&wr>%@ux(SYwzx< z43XlANXihwY~B}2U$o-|Dw|m&TZ_}&5ytBD&A_n?U!W$Ju2_}Ix5GHB{VdIJrQ6sACZ8&Wj{Ek$mX=R&ZSe<3E?MdE`LUTF|zXmdlX8KwU%}F0wFq(o6y%~ zbrp}hwOb&7xKQ6SqNeq z9Rza>^sPku1?2`zf8~~E>N=&gkji^~r%o``2-FS{6r}@&;8CX6^>FhO;S3%KQlWd# z9L2CL{^+MepI+<>V-6m9_NAOJ#dxH;Ysvf}d?~4D^Q#@{Q^fqusgR0qnw>fz9Ytu%dl%xIT+B~qCHpk6P|<=Iv_gEA;#V%ITI+%b zat-TyGTAr5OGK?1IDuESVRadu09&;61@aNq{WV+yA_At37Ot%){1UUVU8gJ-wxy@R z1qD-HuuYjmtrl~#-a*HYI(m{cy zyaqVTNaT3nc|;S{3At_pI9;}!s*1t%Vre|3}M@r0S4bVBvgD=kJ z9PM;poWM1X;({RqXZ2C+{~Gyw2F@L{fD)8XwE-*T)rz4M^ByQ%>Ln=$t#ncJOWc8c zAD_bu(x=K4&nTrRI5hxg9OVwgX~$j6KYPl2(eWROz~GY6Nh6c9ghuT#f1EKDbeS7@ zyQ1AoaTVdWNuAPWdxnMSG>1-t27Azwg+!Dl@DlaYzdk}rdFopk5INr1Y9~K2sUW*vW z!mcfOZD#r1>9HJXAGfkQGQT$GNDC=7eBE_`#P1CLKuNi4I+entM`&-LJGf9UlodiwSc7NBs#6F7Du~zVrkmtsBMXUpDmE&L^Kj zaHf9POiRt;dG&7PuQrBs(L!iGbnQU+h`Qfi&fbKq2-#|l#XGk2X&Ife7Sp1=7Y`)q zLihcAFs+q2E&{LqZhUk6M9J^6wT2|}sFF&vg}*Ogs1BI^Yb2)q=SN)wDz$2PyUrH% z%&o*`3cb*}cXmi+&%Os8E`m&X`HucLWoq8aLx+qz#_pY=L-oI_UEyH+s`d3VPtTDp za2^S$;ku*PE1!C-JuO3pLzf9~*=ZJ>CV6taS%fZFkH+{zGSqGp(bbScm%BwbQr2tp z{OJPUe-!b1j&4q9=QeBI+M113$FLIw67_B6+7^tr2g6yMOVx4?OVraA!WatIg2QKipq!+OR2yS3Yc&RUD9 
z&PR6ZN-570-`LGgld=4|jdOG1n-Jch6C+R`0HhUda8MY5QH7lntJ!-62Lp-TwD_>4 zB|pj`dBB}D^s;y`2rO~+==Yl$Iw3G@XlrtMoN}n$?f$gq>>{&q$DMk(+izbBux-9> z9nCBoCMdHYjc;rvgSxO8A9?*m8zY`YTZpYa`9w1zswB@ zs60_2tJBcbj=0!Y)^vvK2~y@kzX7WzNP2iyaa78`NJ`f<@NjWn6RQ&eW;Y5dygm%* zJ-9fM(0#TDS0J@#clTq=wjT)YSCndv4`d%FRzhZXsN9lemk#VBZwoC2mZj&a+^ZS( z6fAqpkp?{+-)nvWu3l6Q+9D--_)#%&c{@ep=)NXFW_I^yN}l*!mUKWnU=c5X?3k(Y zc3oEF4qCMvSnWHF%)q2H;|V5r1rF};>poxB36?)iy18CYI~VL>|NYe=Aom^>s;Sgd wlB=gKO12cVdfuVo9^$pBkk1~`agY5v{zsLtHDMF_k0*wzqJ~1boO$^F0OX%vk^lez literal 0 HcmV?d00001 diff --git a/doc/images/microsoft-small.png b/doc/images/microsoft-small.png new file mode 100644 index 0000000000000000000000000000000000000000..56f3334e96c206f1a8ba02f04cba7dc0ef664ea7 GIT binary patch literal 8047 zcmV-#ACTaQP) zaB^>EX>4U6ba`-PAZ2)IW&i+q+O3&wa`dx`-%H<* z2PyA2ghQFv{Q9$<_xZ;0`t^bSK0nRx$3tK5>x}n>-fw(dF#DZ_U-o;E?+ayl9Z&DS zgUXL3kFS4siTa^BUvK^VyL5`NE7!{sm~mDa$qwE)~C#a z4O*P9OU#ic~GO2oY@i%A&-VoSNEmSPv_DN)A(KE}$Sx=avA zHB{NGQVP5WoFRU1KGV~4u2gv)MqU7-=9W9q&_IAyN-d3k8&JqX&9&58TkUmJt=4iY zt+v*BbDgrDd+D{e-uvjwU;}Xf`q&j9GCMAJ*cLX4`IA(FAQpAyS%%VyT zIagto)xp`x5d*<;lGK|%vHL>qKjIds?%%~tenT!ebpH?Jf%&Xj2>g(rIF#YSoOr$+{M=&3@0CbzM$BY32oM%Dr3&!8?(G%~5mitJHfDwPX943WJCal_^9%nX1wC^a2tvzp=eWh!0u$Ce!)EQ5 z!y#^p+uDE=eeGVC$Z@Z{f~t&KHr&y5`}8W-2|1+<%l!b?1+YESxx3T-xL4n)9`iDd za>m;C(e5D~KZk75)|w#Pma(*wdWUOuR<$|-dexc^!p6?Z@lyAhcjv!xN_7aJY3J%J zzLcG9!6uzGLZ>oR@I0eW$R5ht8oie1&AZKYITr1~>@UA2O1Ccjvl?b|a>jLHw;d(Z z=8Y|a&TM<{fkz$HTkY(-Qf>)iLNb}NG##|4$~ko@`REE&x^s?Wd8!>6gl*75cxGZF z^k(GkmVC$gWIOC1gH0#HHtxE{PAd$nmAU{^O+a*JhClJ_dT5iiSF?Bk?^G*&->HM8 zdqwL^mHSb4zGzgsf3-UJAWBiVfZqx4uKC_x@7~IVFExjq{hZ0tq~Co@hHMm;OWxXZ zPT`(%7o4IFko`<+W&7KiJ|%3pYf93eomS`1+k`sNjXHq3wzGw zu+}|Ggo&xWMWRk$Ef+)%>?qXO;ih%2d#fu^Ndx8_rv^y! 
z$7-H{5X=o7|`Y8|W!YHt|mBv=xnx zYf=o%Jhccj}uaavk9iXvAH8$Q%I^nF|y+&Okd2X*F-D!~{; zYvUKP65V@HBboF$Fw)13v+$}3Av>hqbA^uUMu?SNv*6-Ha3jz|K3v0K=p4Pv1yV*A zED>8vPze?zi3v%H>}6Y6T#vct3Z|CGW>jyX$EYNdig5M@%Na8e#0~h$E(>bIeQ+P3 zCDPCVeFOzBD6Xf#`!Y`=K&UwZij8v{`2gERNAEu9E6t3?20#KL)y^zNs9FQp%cxpB zk$4CbGXFWrPeT=H>{}=91RZz}si=gRo}US_9a=5zLRYwd!Hv?0l?!wMC?bx$M<{O; zPnJ{L+Z84i*apisz)Hq&p5lVlPE{HUufUJo))2BvtbwsO%Vpuj6v7Quo2%wx zphM|vVfkKBuQ#oLbdU=X#sf*v76#c~2zBAO_-TiJB%!pn5Bu{zeb^>yA>gSVdK8)I zdf*fdfM1mjhDBted3&V*NzBKC=ssS=u=hQvIkfjml7>law-H_U<`TQeZV3A&v?6eZ zhx(iqOMoH(a}>aMUm89$;W|}i)CO8|WkxmBnZ=}F38|aWa3&d=Zf|a6*lj>Cd-bLk z>SLs@K6bz0+u2fH%tTrj(Td7U66K+Z$$#|f*g(v*gjKteK>n=!eRYq@IUIr#O&z}y5T z+oLC*$3jDB`^5IUa>l}%EFkr-lyki|i0J+*uKU#dH3Uq{NF&SDjyT^-RCANSQ=Z_E zkp#H`)k4-*;|hoXIVm!cG=vUUVMr$x_(K^j<5FP1k0xC>vK#s$U-5-1Rkfzr|R1SET{o5|wIEmWPQ-Ay~yaA_BWy~pPQ(dRT{#TNx84bYXO zl3ydm-1Oswc%w6d%mAqe9=yrJ3{7d0+1-p#_8aLXI+@@FcW+%#zbnHru0XA`B>V|k zne_9`L_b7@CT;MrZ5TNTSewLxx;9ud1|b}vE+&EsU9^Z9Eqg|1%*=)i;VbomnooU* z3ceJx5ps|Wda~R^M8puEg)vnTMVchP<6&*CJ}5KGMnzPKH3+jJ9+ccO6V&vXfL@}C zXms-jFPEW}F&Irq(oB6&-#$5?M1b{Sp5a77o&*FVH=zgSqY?zm&yb@;2q?DuN}lE_ z66-PY;4TlMeZO%{Hds--jQBKIXO>M6K*;mnKcu=bC&*{G$$jYSO^jy|_?w`hQ5hKc zJx_0l&A84NlRkcr&!JI-QCB6b%sIWfF}5BCNpSbaTP}O6`rP>9(x%m!I+>0vV=&gWiSIWt4(qRaS`-MIp?5#gkCHcrgbLtj03i(yE##^zdPF`VI8Yx$Wh(5KP=PI_`<53i3K3IP;Mu&uZNO2IWLO^az0~ay=wVf~A;O z5jq`7$`J({_=oZi8aW7HVrfcsgdTeoBX-XGiOnJ zz@QtHnV*qV+`-B6{$P>(VLq=B-MsI$(8Oedz(j}FN~}Jn~HM`U>OU`F6^!`q~>#dJ)AF&TTRq7RBck~kVr zflZFXn_wBahX~9}gQMQ06r=6cPu4Bq!IiIMc7@yGCTH*p@30*vKq3U``ClTq93iDaVc^lf!LQl{?cq@vjeIx&3AB8oDV zJ@N_J%j+=(m3;}UXJVGi{77VIif|8aHl{E7xcI0@XRjENz8^o?!qXu1_WVNUZblzA ztL$f?0b3&a>;HF)u?vA+o=VLr0c$V>PyS*hdePQtv_n+f$+cFE2<-pMg8>fk64NmQ zWg;I89Eo&oV9QB-ypxOM za~JZ+@ew%Qal=Ptf51WBcvkRRB7GOVTqvf zngFj>0004mX+uL$Nkc;*aB^>EX>4Tx0C=2zkvmAkP!xv$rixN34i*$~$WWauh%X$q z3Pq?8YK2xEOm6yuCJjl7i=*ILaPYBMb#QUk)xlK|1Ro$Su1<dj$A?7vov}_x@Zx zYR+OnKqQ`JhG`RT5KnK~2Iqa^2rJ4e@j3CRNevP|a$WKGjdQ_efoDd{bZVYBLM#^A 
zSZQNcG&SN$;;5?WlrLmFRyl8R*2-1Z?32GRoYPm9xlT2N1QxLb2_h6!Q9>CuVzld| zSV+-%+{ZuY`XzEHFM$M+TUg_qxSG zH}S`=Mu>awdozO!V1Pm$cVWX1T{eouQiGM!HLO*NKSCMBK>-EDETDjbVixfK6$1tgC@5wD ze;!PoIOJ4bx4w*~_Yi$xZHH{N(7t5&W0aw}%o2f!7vv>Sl6B7&-RMaFa>6-(Oz+!wBtBC;Cz z3V?`A5s_zse=n$DZmeFt8e`1x97?$z&|JH8Sy@?P@#4kVFa3%cAeZJ1Fvk4A^Sl>2 zdN7qrp{hMa#C9ZCL;y-6c_Jbw7Syj}h{)8q<&@24mjmiJ&M7lz&ZMNIne>lDKy*S~jN zck$M(TbpWXY7Tjxx4opKq&KixL>>XI5RpEA310c-$z-zE(xpo~D>thv!aUDgt*Ym# z>OjYF#(JLDQdCq_P`!@PBXa%U0IOE5!uS1Xna1PJE6;&9fzU?j>p0HLj=XVW%wGjI zDyoXAu4`}050rDS=Xu*6d+ad)ZomEZ4ylo3GI?S$ne0N zPoLhAnbneGSNl9xrunEcdq+dbfnLCV5&4a(J`1e$Ja4x#rgz{N z6bCY2@jP#_wRTG2yp;kQUDv(8s;Y|W>S~h7WbZ^GaUC#DRnG(dIvC*p1h%N^V$bv5 z1QZr3X7c38xUP$}_EMk>7^8I}lWvbdNBI6=V$9CXXF@L5A`1?p7yd@&<1q-F9 z>g?lXPdcQk4~j_0H=HFR->$B%-V`mgV#Nw7Dk_|y;{fn|-*p^kR0p#E2AB{i!gk=E zFk{c2Jt1{8&>F&sE2xZMr>hHixYRa8{at5+{tT3TG7E3Ce@Gbq;rXXeT?#*743#h8KiFvfA5 zUl?P?_`dH4T;eTJr3|n?RPwl)QCC-owRR5hXcryLM}bp<-R4-Dh2jJqmd)cA?evyB zW%=EGkxV9Y2zV~iX;V_E)XAepwJwIY+;R)4RI12voWmi0s=7FxPB(VT=BTaZ(4j+D z0#%XZEU;8nF9l9h)vu}Q?Z6xD@v0(nRM+Kf4?GyN{XF1;fWf#> zRlV3N+-Ca@7m0yBVvk^J|8^5EUV-|U$4^K9UZRs}M)7bRLgUa7(_^RzMd;uQIZ zXc;K^&79A>+l0!>N>sI0L_$lvSVYFITet4#0|yS|nP;BCT6-p~@uvVys#=rRv=&(i zBtok|RqyaT?`eQD&Nzc@+qMBb1w55XrTTUt>R#7%=fzgbvS9TOvD9_lsQ?#Wd@=p| z_heMY8-UkC8!oV&7dwt4zzLD(U-vxkb%35ddp1s;I`ypu3l@}BRaH5W_i;HA z=&h=HyXSfT3ec}#KN=buK5I26?G7N?-lXfg(}8ybd-IXN#s>Ib*L4@h{Da=WKSbL1 zn(Mlcw;|Y1oUpry{G!eUk$t+gRB$K>?UHVy<)?W(?0MdM*4p*JwLwQu0Dk__M<0bf zeV2%YM){kL<8140UDn#c!1gYft+iLV zuKRLRQxnb2&7mFZPn|lIzJ2@l1+I&f|AVR?J|k>sXdtrFcp_2`d?QHuQAI^XRc&o; zQ@3^KD7r}XCR9Jkgcrc%O zE{NSfnd=OF!-fr=iF(6z-R1=g7SPz($i|HuNhA{efzu;p-fM1d-g)b-ZT+J)Yu1F9 zFUHz3GBAr?kJa^pwf0$K%r#Fu@dN;a2M=b+k|n6>H)CbjrPJx=Tn+YpzXkYhBhjsfPl>Md zs`_#_6?|`w@dDT03O( z=FMT==`qoNY*%Soxz)zld-Oflih10irlyAK>gt248l^{mluD%%vGd#dMT-{YJ0Z_# zJ1*MXO~q_)T)H$it5>gXSElQ_e^S*;fTwdjIv98@l}dfDtJ`ZQ@C%CBITl1_^r?YE zbXjEJo;n6~{62Es!*QI;k74^8V~$`+Wo4x#5{dY#9bWd%)wWTiM&Wthe%Ey;`o8}S 
z;QoLj3vq(uIIEM%WDgNJ5R=iTot+{)T09}PV|-dr%nmSp`gBxvXT-h|5vhr+0B^Un zw7ho=Y}r4?25dYW9mE_z?(3rL9l)O=NrMuJME`8ItqD70#*Fatn%MKN0jyiMj@xd# zjbX!vkxr-o$93I*%4V~JX$?vV(JzrmoDIAkYsM&Jj8jojk!v+B0=^qb-rn5Y{JUdm z!X0;txp3h^(&@AUKg;o+>ihnKH8nNv@vE(CexPBUcP)e#+VXg%;{Z~ zold8lRCQ4#=|mBk3vl6u7amT0{PD*z##}<%Xr7Jfbb5!iR;;zB4Hz(hl9D4-XeN`{ zrmFFH(w?sC?ua=w=cwvUwY9amY|J^pS&_b|nLmHNc2^lybx3$UWy%z$O`FDX8z$qr z?hDr14{00azKdh@-FiBm-eax32M7=29T8b+jHy>u7dYrRPLZlk2<+P#e1QPtY&Kh4 zR8({iM~)(=SZlumJTQFt@V%<4A~GSMm7K^R-xryTuQoL`eKcs$pr4+9{`p(>?c2Aj zPoF;Lipa3&Vy~)OgJoCTU&1N5mHir{MA~MNwoVOjv z`4IR(L>{1>BgghIdGchctE+d@8et!Gs;&qAOGG|&9Ookud6?Go_)sYy^gQpiS+fpD z^p%LnND+C>ahy*|OG`g>9A`^xvSvKb+g({%iRXE*0&}D91C%?Cvn8nST@fkc$e}_L zFv;`0o$)BBP9|Lw@KV%97m=lo<9rNk0Q!EJ6;s;tAm4rW-7v&#O>E1UJ8#}RR<2yx zWm)aP^E@({Otz?~=!djMH+A$KP_~o&HdwZ7S-9%g12+K&I-$}Uf!UttJv3|9EauOj zpUYG3LGX!4WGnqQRW%`_5!87v@Q=WuPPX9&RlVMI-IZl!W${?c4!!$KCZpMG_U<_4 z6b%6RGH>O}fnNtN{@Xb15^$7TthLi<4cge{dEN)@$XXxaKKI90S&n?KvM)Mdg9Z&^ z-n@B~mzPiK)vMP-v_=SB1AIG}2=9x?%c@#0A|C{u|9P+~6$NeI)ZPo0mzU$Z?sLYN zGk~i_?t zqOq~@ow~X@z3#f}c>ek4IrGdj89R0?uIoNvj9CEO3|uK9R{;rC-6;-89H=m&Nd~JNy*{Wx531T6LUUWSy}1CoS-iCXRfs9x8<{D z%}NAqJ92K%(+znC3L{*kr}TRV$v;q7uI*Z&Z~@0R3jc|tfP!KcP(VR33n(aN0R(zg!W@@Tys=K;x-+RwFJ@L)v~OEI&AyV>We`Jua~_wZ!y z`JMZF5xIit+8t%df*g@DeB@cU`(?Ygk5%}3;e~JbeX!i_L*k4)MhP>9y!@ckaRoA#M?l8s3k}OHf(cZ={J5O%@)76Xt9 z8MoDKcAUuF;Y+emQbac?3G&f7OHBRplKZ>1yw8=W@2l)q1;2kQUB5OQ1cLS+n6n7- zJ6;X1GM``OROQ`Fy2Hs(hH6?iLrORPr4wQ6&)Zy)OP=hi!b!7+MX$1rGelIH#AHo4 zZ%r%t^UP}aky97?&4evKSky>9~CIkK6bSh6SGd@|vXKCX(A)S%T$RathP& z;wG15-fiRWb*_DWVq`*njz=8X*cF1Z;#56@e^F-Bs-sS8qS|}Tdh5K2c_Vw>G~b-i zU25dNU5LIu`8|ccR3?k3>cc$hexyRKPNdfQ{>}pU{nxsQuYPC7!j5HMa$`{7TN_tJ zwadlz3j&K+RA+$qV1)u2=@sM1fr8QwT%*5(D)nzx(rPz(Pveb6iNB5XdAW{Vrhp*t zm9LWAIlxr@RqV`Xlm5+XG0|
Iv2uJ1J3dS!piU-GzX_t2g^?bD&6TJA+0y=^+4 zi7ALa7Q#yaySYD)pQbPw(f`CRb+A;G&-qdu-Aw)nI83SZ!!5oDE%RhWui$JmPX8no zn~e`SJ9#}UYVQiS{LJzE3#rAma#(on2EHX|z70>S#y&52S*Yx%<(W+P=noY^8Xf=@ z*fDaPUO?$lw^|;E+v$s|HvYz5JDxLAddUEp#7ENE;lW~8_|=4yuJ$51HE zz7p1^e%4su!){D{hPG0_KjPNnx;4ZCbNgkB}H6Siog)r%is1QJ|a(i zjbSc?`5Kyzb8Vzh9J6;6Aj*oY;<+VaA<4pq+ym`e76Cb$pu*qIwap3i=0^s2(E?0zC^}qBrWO zES2c~toLz_izXN~L?f#cZ@d`z9{9|NG}Nt~kKPeCxas4mw9Z7pkz{%bpX3!ULRdDr zpGYpgb&`qI>94_uc!CZJ??y!M~6Ltcq40Y^uk_U zXi`hSGj5S?5t#C*7_E#-R0^c>ASg;URGAbv*bnoVAfNTMijz zYilCJo=?AwwZCfA&G7@0J0;Xa_bPmI5I@><7%#8ztKyMZO%crnZrdct8FMrbv-r97 zFr%UB%SEzq5!b-|*0J9yd_UUj)CRtA_K-7HUzUsL?1EFCR}aoq*{<4BelOn`+#XgK z?edR&S|3DP8D)|;5%QR2I-k@*55_;GNoSyGsC-KX`*o`#!`7wL7`3NL&Fyj*lxH5!K zMp_H+quA=%DJ3ul6bv@Uy=sSYZdXHy0p&t|^psemXOd^Z_@LTPdO8LMQDHyD!8_!n zdi9IgOA4B#SkB&WS|J|Qmq&%5q9`_!d=2qCwU0i!D<2CCK;|h^11px&0`F3MEH|?l ztXHMk6$n3K=Av~*7Z1`m|D4S*r}y&PeRtO9=Vx|Jz-irb`U7i4BSht5L#3oGtV~{@ zS~9}R7-~3jhza$2!L>lViI`Jo%YTO#ulW1g+3f<;WZ;8CAylV+5q(==k>6zzwgDX& z9eBf$u5ie}hPT_nniuu;{WunvgLv-#)1tOqop!uqivyv29G>thLTk;|&z&wF zJeFBFEH%2xb*`$R>L~3qx@uyHiOLH51zsFa@yz->=~Os8W&s(^vNSmMP01I>5u4@8 z>4TMXlDEiaK|VGyCYW+Ry0l`Y4mnadAjHZZPW%+9Fb+RF)csrhQk1xyijcxc$eh_1 zLP4*of?-xpqbMY?Bh0q=49}hI=sHp3gv)Vvh<)IK2YZ#K3ImN^k~2KS-GVosr$*P#p8_7$ME60tMrfOeUZ~{anYi8PSf zVRKy7BthTG_t@|5(VU~iR=MXfTP`Ic{DEM)kdtFQHjI-Bs)XAMw>Q-HCQd@ zC9IbDK}#1G&Ctx{SOODXbSqU=3?lQlIovc