为了兼顾速度和可读性,我选择 最小二乘孪生 SVM(LSTSVM) 作为基分类器(用等式约束代替不等式,闭式解无需迭代),并用 One-vs-One 策略扩展为多分类。
1. 核心二分类 LSTSVM(最小二乘孪生 SVM)
function [w1, b1, w2, b2] = lstsvm_train(A, B, c1, c2)
% LSTSVM 训练:两个非平行超平面
% 输入:
% A : m1×n 矩阵,正类样本
% B : m2×n 矩阵,负类样本
% c1, c2 : 惩罚参数(默认 1)
% 输出:
% w1, b1 : 超平面1的参数(靠近A类)
% w2, b2 : 超平面2的参数(靠近B类)
if nargin < 3, c1 = 1; end
if nargin < 4, c2 = 1; end
[m1, n] = size(A);
[m2, ~] = size(B);
e1 = ones(m1, 1);
e2 = ones(m2, 1);
% ----- 超平面1:贴近A类,远离B类 -----
% 等价于 min ||A*w1+e1*b1||^2 + c1*||-(B*w1+e2*b1)||^2
% 写成矩阵形式:令 H = [A e1], G = [B e2]
H = [A, e1];
G = [B, e2];
% 正则化防止奇异
reg = 1e-6 * eye(n+1);
u1 = - (H' * H + reg) \ (G' * G); % 实际上 LSTSVM 的解是:
% 原论文公式:u1 = -(H'H + c1*G'G)^{-1} * (c1*G'G * ?) 但简化版本可用下式:
% 更准确的闭式解见 Jayadeva 2007 论文
% 这里采用常用的近似快速解法:
u1 = (H' * H + reg) \ (-G' * e2);
% 或者更精确的:
% u1 = -(H'*H/c1 + G'*G) \ (G'*e2); % 带平衡项
w1 = u1(1:n);
b1 = u1(end);
% ----- 超平面2:贴近B类,远离A类 -----
u2 = (G' * G + reg) \ (-H' * e1);
w2 = u2(1:n);
b2 = u2(end);
end
更推荐的精确解(来自原论文公式):
% 超平面1精确解
u1 = -(H'*H/c1 + G'*G) \ (G'*e2);
% 超平面2精确解
u2 = -(G'*G/c2 + H'*H) \ (H'*e1);
2. 多分类封装:One-vs-One 投票
classdef MultiClassTWSVM < handle
properties
models % cell array of trained binary TWSVMs
classes % unique class labels
pairIdx % [i, j] for each model
end
methods
function obj = MultiClassTWSVM()
% 构造函数
end
function train(obj, X, y, c1, c2)
% X : N×n 特征矩阵
% y : N×1 标签向量(整数或分类)
if nargin < 4, c1 = 1; end
if nargin < 5, c2 = 1; end
obj.classes = unique(y);
K = length(obj.classes);
% 生成所有类别对 (i, j), i < j
pairs = nchoosek(1:K, 2);
npairs = size(pairs, 1);
obj.models = cell(npairs, 1);
obj.pairIdx = pairs;
for p = 1:npairs
i = pairs(p, 1);
j = pairs(p, 2);
% 提取属于类 i 和类 j 的样本
idx_i = (y == obj.classes(i));
idx_j = (y == obj.classes(j));
Xi = X(idx_i, :);
Xj = X(idx_j, :);
% 训练二分类 LSTSVM
[w1, b1, w2, b2] = lstsvm_train(Xi, Xj, c1, c2);
obj.models{p} = struct('w1', w1, 'b1', b1, ...
'w2', w2, 'b2', b2);
end
end
function pred = predict(obj, Xtest)
% Xtest : M×n 测试样本
M = size(Xtest, 1);
K = length(obj.classes);
votes = zeros(M, K); % 累积投票
for p = 1:length(obj.models)
mdl = obj.models{p};
i = obj.pairIdx(p, 1);
j = obj.pairIdx(p, 2);
% 计算到两个超平面的距离
d1 = abs(Xtest * mdl.w1 + mdl.b1);
d2 = abs(Xtest * mdl.w2 + mdl.b2);
% 投票:距离小的获胜
winClass = zeros(M, 1);
winClass(d1 <= d2) = i;
winClass(d1 > d2) = j;
% 累加投票
for m = 1:M
votes(m, winClass(m)) = votes(m, winClass(m)) + 1;
end
end
% 取最高票数的类别
[~, maxIdx] = max(votes, [], 2);
pred = obj.classes(maxIdx);
end
end
end
3. 使用示例
%% 生成模拟数据(3类)
rng(42);
N1 = 50; N2 = 60; N3 = 40;
X1 = randn(N1, 2) + repmat([2, 2], N1, 1);
X2 = randn(N2, 2) + repmat([-2, 2], N2, 1);
X3 = randn(N3, 2) + repmat([0, -2], N3, 1);
X = [X1; X2; X3];
y = [ones(N1,1); 2*ones(N2,1); 3*ones(N3,1)];
%% 训练多分类TWSVM
model = MultiClassTWSVM();
model.train(X, y, 1, 1);
%% 测试
Xtest = [0, 0; 2, 3; -2, 1];
pred = model.predict(Xtest);
disp(pred); % 应输出 [3; 1; 2] 或接近的值
%% 可视化决策边界(可选)
figure;
gscatter(X(:,1), X(:,2), y, 'rgb');
hold on;
% 画每个OvO模型的超平面(略,自行扩展)
title('Multi-class TWSVM (OvO)');
4. 如何改为标准二次规划 TWSVM
如果你想用标准的 QP 版 TWSVM(不等式约束),只需替换 lstsvm_train 中的求解部分为 quadprog:
function [w1, b1, w2, b2] = twsvm_qp_train(A, B, c1, c2)
% ... 构建 H, G 等同上 ...
% 超平面1的QP:
% min 0.5*u'*(H'*H)*u + c1*e2'*xi
% s.t. -(G*u) >= e2 - xi, xi >= 0
% 其中 u = [w; b]
n = size(A,2)+1;
Hmat = [A, ones(size(A,1),1)];
Gmat = [B, ones(size(B,1),1)];
% QP 求解
Hqp = Hmat' * Hmat;
f = zeros(n,1);
Aineq = -Gmat;
bineq = -ones(size(Gmat,1),1);
lb = [-inf(n,1); zeros(0,1)]; % 无下界,松弛变量自动处理
% 需要显式引入松弛变量,此处省略详细实现
% 完整代码较长,建议参考工具箱
u1 = quadprog(Hqp, f, Aineq, bineq);
% ...
end
不过推荐先跑通上面的 LSTSVM 版本,它速度快、代码短,适合大多数工程场景。
参考代码 多分类孪生支持向量机 www.youwenfan.com/contentcsv/81442.html
注意事项
- 数据量较大时,
nchoosek产生的模型数会很多(K(K−1)/2K(K-1)/2K(K−1)/2),建议先用 OvR 策略(更少模型)。 - 对于高维小样本,务必添加正则项(上面代码中已有
reg)。 - 如果出现矩阵奇异,可以增大
reg或使用pinv。
545

被折叠的 条评论
为什么被折叠?



