@@ -1101,14 +1101,18 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
11011101 coef_ : array, shape (1, n_features) or (n_classes, n_features)
11021102 Coefficient of the features in the decision function.
11031103
1104- `coef_` is of shape (1, n_features) when the given problem
1105- is binary.
1104+ `coef_` is of shape (1, n_features) when the given problem is binary.
1105+ In particular, when `multi_class='multinomial'`, `coef_` corresponds
1106+ to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False).
11061107
11071108 intercept_ : array, shape (1,) or (n_classes,)
11081109 Intercept (a.k.a. bias) added to the decision function.
11091110
11101111 If `fit_intercept` is set to False, the intercept is set to zero.
1111- `intercept_` is of shape(1,) when the problem is binary.
1112+ `intercept_` is of shape (1,) when the given problem is binary.
1113+ In particular, when `multi_class='multinomial'`, `intercept_`
1114+ corresponds to outcome 1 (True) and `-intercept_` corresponds to
1115+ outcome 0 (False).
11121116
11131117 n_iter_ : array, shape (n_classes,) or (1, )
11141118 Actual number of iterations for all classes. If binary or multinomial,
@@ -1332,11 +1336,17 @@ def predict_proba(self, X):
13321336 """
13331337 if not hasattr (self , "coef_" ):
13341338 raise NotFittedError ("Call fit before prediction" )
1335- calculate_ovr = self .coef_ .shape [0 ] == 1 or self .multi_class == "ovr"
1336- if calculate_ovr :
1339+ if self .multi_class == "ovr" :
13371340 return super (LogisticRegression , self )._predict_proba_lr (X )
13381341 else :
1339- return softmax (self .decision_function (X ), copy = False )
1342+ decision = self .decision_function (X )
1343+ if decision .ndim == 1 :
1344+ # Workaround for multi_class="multinomial" and binary outcomes
1345+ # which requires softmax prediction with only a 1D decision.
1346+ decision_2d = np .c_ [- decision , decision ]
1347+ else :
1348+ decision_2d = decision
1349+ return softmax (decision_2d , copy = False )
13401350
13411351 def predict_log_proba (self , X ):
13421352 """Log of probability estimates.
0 commit comments