1+ # -*- coding: utf-8 -*-
12"""Algorithms for spectral clustering"""
23
34# Author: Gael Varoquaux [email protected]
1112from ..base import BaseEstimator , ClusterMixin
1213from ..utils import check_random_state , as_float_array , deprecated
1314from ..utils .extmath import norm
14- from ..metrics .pairwise import rbf_kernel
15+ from ..metrics .pairwise import pairwise_kernels
1516from ..neighbors import kneighbors_graph
1617from ..manifold import spectral_embedding
1718from .k_means_ import k_means
@@ -287,8 +288,9 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
287288 If affinity is the adjacency matrix of a graph, this method can be
288289 used to find normalized graph cuts.
289290
290- When calling ``fit``, an affinity matrix is constructed using either the
291- Gaussian (aka RBF) kernel of the euclidean distanced ``d(X, X)``::
291+ When calling ``fit``, an affinity matrix is constructed using either
292+ a kernel function such as the Gaussian (aka RBF) kernel of the euclidean
293+ distance ``d(X, X)``::
292294
293295 np.exp(-gamma * d(X,X) ** 2)
294296
@@ -302,12 +304,27 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
302304 n_clusters : integer, optional
303305 The dimension of the projection subspace.
304306
305- affinity: string, 'nearest_neighbors', 'rbf' or 'precomputed'
307+ affinity : string, array-like or callable, default 'rbf'
308+ If a string, this may be one of 'nearest_neighbors', 'precomputed',
309+ 'rbf' or one of the kernels supported by
310+ `sklearn.metrics.pairwise_kernels`.
311+
312+ Only kernels that produce similarity scores (non-negative values that
313+ increase with similarity) should be used. This property is not checked
314+ by the clustering algorithm.
306315
307316 gamma: float
308- Scaling factor of Gaussian (rbf) affinity kernel. Ignored for
317+ Scaling factor of the RBF, polynomial, exponential chi² and
318+ sigmoid affinity kernels. Ignored for
309319 ``affinity='nearest_neighbors'``.
310320
321+ degree : float, default=3
322+ Degree of the polynomial kernel. Ignored by other kernels.
323+
324+ coef0 : float, default=1
325+ Zero coefficient for polynomial and sigmoid kernels.
326+ Ignored by other kernels.
327+
311328 n_neighbors: integer
312329 Number of neighbors to use when constructing the affinity matrix using
313330 the nearest neighbors method. Ignored for ``affinity='rbf'``.
@@ -338,6 +355,10 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
338355 also be sensitive to initialization. Discretization is another approach
339356 which is less sensitive to random initialization.
340357
358+ kernel_params : dictionary of string to any, optional
359+ Parameters (keyword arguments) and values for a kernel passed as a
360+ callable object. Ignored by other kernels.
361+
341362 Attributes
342363 ----------
343364 `affinity_matrix_` : array-like, shape (n_samples, n_samples)
@@ -381,7 +402,8 @@ class SpectralClustering(BaseEstimator, ClusterMixin):
381402
382403 def __init__ (self , n_clusters = 8 , eigen_solver = None , random_state = None ,
383404 n_init = 10 , gamma = 1. , affinity = 'rbf' , n_neighbors = 10 , k = None ,
384- eigen_tol = 0.0 , assign_labels = 'kmeans' , mode = None ):
405+ eigen_tol = 0.0 , assign_labels = 'kmeans' , mode = None ,
406+ degree = 3 , coef0 = 1 , kernel_params = None ):
385407 if k is not None :
386408 warnings .warn ("'k' was renamed to n_clusters and "
387409 "will be removed in 0.15." ,
@@ -402,6 +424,9 @@ def __init__(self, n_clusters=8, eigen_solver=None, random_state=None,
402424 self .n_neighbors = n_neighbors
403425 self .eigen_tol = eigen_tol
404426 self .assign_labels = assign_labels
427+ self .degree = degree
428+ self .coef0 = coef0
429+ self .kernel_params = kernel_params
405430
406431 def fit (self , X ):
407432 """Creates an affinity matrix for X using the selected affinity,
@@ -419,18 +444,22 @@ def fit(self, X):
419444 " a custom affinity matrix, "
420445 "set ``affinity=precomputed``." )
421446
422- if self .affinity == 'rbf' :
423- self .affinity_matrix_ = rbf_kernel (X , gamma = self .gamma )
424-
425- elif self .affinity == 'nearest_neighbors' :
447+ if self .affinity == 'nearest_neighbors' :
426448 connectivity = kneighbors_graph (X , n_neighbors = self .n_neighbors )
427449 self .affinity_matrix_ = 0.5 * (connectivity + connectivity .T )
428450 elif self .affinity == 'precomputed' :
429451 self .affinity_matrix_ = X
430452 else :
431- raise ValueError ("Invalid 'affinity'. Expected 'rbf', "
432- "'nearest_neighbors' or 'precomputed', got '%s'."
433- % self .affinity )
453+ params = self .kernel_params
454+ if params is None :
455+ params = {}
456+ if not callable (self .affinity ):
457+ params ['gamma' ] = self .gamma
458+ params ['degree' ] = self .degree
459+ params ['coef0' ] = self .coef0
460+ self .affinity_matrix_ = pairwise_kernels (X , metric = self .affinity ,
461+ filter_params = True ,
462+ ** params )
434463
435464 random_state = check_random_state (self .random_state )
436465 self .labels_ = spectral_clustering (self .affinity_matrix_ ,
@@ -457,3 +486,5 @@ def mode(self):
457486 " 0.15." )
458487 def k (self ):
459488 return self .n_clusters
489+
490+
0 commit comments