I'm working through this exercise by Andrew Ng on using k-means to reduce the number of colours in an image. My code gives the correct result, but I'm afraid it is a little slow because of all the for loops, so I'd like to vectorize them. There are a few loops that I just can't seem to vectorize effectively. Please help me, thank you very much!
Also, if possible, please give some feedback on my coding style :)
Here is the link to the exercise, and here is the dataset. The expected result is shown on the exercise page.
And here is my code:
function [] = KMeans()
Image = double(imread('bird_small.tiff'));
[rows,cols, RGB] = size(Image);
Points = reshape(Image,rows * cols, RGB);
K = 16;
Centroids = zeros(K,RGB);
s = RandStream('mt19937ar','Seed',0);
% Initialization :
% Pick out K random colours and make sure they are all different
% from each other! This prevents the situation where two of the means
% are assigned to the exact same colour, therefore we don't have to
% worry about division by zero in the E-step
% However, if K = 16 for example, and there are only 15 colours in the
% image, then this while loop will never exit!!! This needs to be
% addressed in the future :(
% TODO : Vectorize this part!
done = false;
while ~done
RowIndex = randperm(s,rows);
ColIndex = randperm(s,cols);
RowIndex = RowIndex(1:K);
ColIndex = ColIndex(1:K);
for i = 1 : K
for j = 1 : RGB
Centroids(i,j) = Image(RowIndex(i),ColIndex(i),j);
end
end
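% unique(...,'rows') sorts the rows on its own. Sorting within a row
% would reorder the R, G and B values and so change the colour.
Centroids = unique(Centroids,'rows');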
if size(Centroids,1) == K
done = true;
end
end;
% imshow(imread('bird_small.tiff'))
%
% for i = 1 : K
% hold on;
% plot(RowIndex(i),ColIndex(i),'r+','MarkerSize',50)
% end
eps = 0.01; % Convergence threshold (note: this shadows MATLAB's built-in eps)
IterNum = 0;
while 1
% E-step: Estimate membership given parameters
% Membership: The centroid that each colour is assigned to
% Parameters: Location of centroids
Dist = pdist2(Points,Centroids,'euclidean');
[~, WhichCentroid] = min(Dist,[],2);
% M-step: Estimate parameters given membership
% Membership: The centroid that each colour is assigned to
% Parameters: Location of centroids
% TODO: Vectorize this part!
OldCentroids = Centroids;
for i = 1 : K
PointsInCentroid = Points((find(WhichCentroid == i))',:);
NumOfPoints = size(PointsInCentroid,1);
% Note that NumOfPoints is never equal to 0, as a result of
% the initialization. Or .... ???????
if NumOfPoints ~= 0
Centroids(i,:) = sum(PointsInCentroid , 1) / NumOfPoints ;
end
end
% Check for convergence: Here we use the L2 distance
IterNum = IterNum + 1;
Margins = sqrt(sum((Centroids - OldCentroids).^2, 2));
if sum(Margins > eps) == 0
break;
end
end
IterNum;
Centroids ;
% Load the larger image
[LargerImage,ColorMap] = imread('bird_large.tiff');
LargerImage = double(LargerImage);
[largeRows,largeCols,NewRGB] = size(LargerImage); % RGB is always 3
% TODO: Vectorize this part!
largeRows
largeCols
NewRGB
% Replace each of the pixel with the nearest centroid
NewPoints = reshape(LargerImage,largeRows * largeCols, NewRGB);
Dist = pdist2(NewPoints,Centroids,'euclidean');
[~,WhichCentroid] = min(Dist,[],2);
NewPoints = Centroids(WhichCentroid,:);
LargerImage = reshape(NewPoints,largeRows,largeCols,NewRGB);
% for i = 1 : largeRows
% for j = 1 : largeCols
% Dist = pdist2(Centroids,reshape(LargerImage(i,j,:),1,RGB),'euclidean');
% [~,WhichCentroid] = min(Dist);
% LargerImage(i,j,:) = Centroids(WhichCentroid,:);
% end
% end
% Display new image
imshow(uint8(round(LargerImage)),ColorMap)
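For the initialization TODO in the code above, one possible loop-free approach is to take the distinct colours first with unique(...,'rows') and then draw K of them without replacement. This is only a sketch: the synthetic Points matrix stands in for reshape(Image, rows * cols, RGB), and the names UniqueColours and Pick are made up for illustration. Because K is capped at the number of distinct colours, the retry loop (and the risk of it never exiting) goes away.

```matlab
% Sketch: choose K distinct starting colours without a retry loop.
% Points is synthetic stand-in data (assumption); in the function above
% it would be reshape(Image, rows * cols, RGB).
s = RandStream('mt19937ar','Seed',0);
Points = randi(s, [0 255], 500, 3);
K = 16;

% Every row of UniqueColours is a different colour by construction.
UniqueColours = unique(Points, 'rows');

% Cap K so an image with fewer than K colours cannot hang the code.
K = min(K, size(UniqueColours, 1));

% Draw K row indices without replacement and use those rows as centroids.
Pick = randperm(s, size(UniqueColours, 1), K);
Centroids = UniqueColours(Pick, :);
```

Unlike picking random pixels, this draws uniformly over distinct colours, so frequent colours are no more likely to be chosen than rare ones. Whether that matters for the final palette is not something this sketch tests.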
UPDATE: Replaced
for i = 1 : K
for j = 1 : RGB
Centroids(i,j) = Image(RowIndex(i),ColIndex(i),j);
end
end
with
for i = 1 : K
Centroids(i,:) = Image(RowIndex(i),ColIndex(i),:);
end
I think this may be vectorized further by using linear indexing, but for now I should just focus on the while loop since it takes most of the time. Also when I tried @Dev-iL's suggestion and replaced
for i = 1 : K
PointsInCentroid = Points((find(WhichCentroid == i))',:);
NumOfPoints = size(PointsInCentroid,1);
% Note that NumOfPoints is never equal to 0, as a result of
% the initialization. Or .... ???????
if NumOfPoints ~= 0
Centroids(i,:) = sum(PointsInCentroid , 1) / NumOfPoints ;
end
end
with
Num = size(Points,1); % number of pixels
E = sparse(1:Num, WhichCentroid', 1, Num, K, Num);
Centroids = (E * spdiags(1./sum(E,1)',0,K,K))' * Points;
the results were always worse: with K = 16, the loop version takes 2.414 s and the sparse version takes 2.455 s; with K = 32, the loop version takes 4.529 s and the sparse version takes 5.022 s. It seems that vectorization does not help here, but maybe there is something wrong with my code :(
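Another way the M-step might be written without looping over the K centroids is with accumarray. The sketch below uses a tiny hand-made Points and WhichCentroid (assumptions, not the image data), and the names Counts, Sums and NonEmpty are illustrative. A short loop over the three colour channels remains, but its length does not grow with K. Unlike the sparse version, empty clusters keep their previous centroid instead of producing a division by zero. Whether this is actually faster than the original loop would still need to be timed on the real data.

```matlab
% Sketch: M-step without a for loop over centroids, using accumarray.
% Points and WhichCentroid are small synthetic stand-ins (assumption).
Points = [0 0 0; 2 2 2; 10 10 10; 20 20 20; 30 30 30];
WhichCentroid = [1; 1; 3; 3; 3];     % cluster 2 is deliberately empty
K = 3;
Centroids = [1 1 1; 5 5 5; 9 9 9];   % centroids from the previous iteration

% Number of points assigned to each centroid (K-by-1).
Counts = accumarray(WhichCentroid, 1, [K 1]);

% Per-channel sums of the points assigned to each centroid (K-by-3).
Sums = zeros(K, size(Points, 2));
for c = 1 : size(Points, 2)
    Sums(:, c) = accumarray(WhichCentroid, Points(:, c), [K 1]);
end

% Update only the non-empty clusters; empty ones keep their old centroid.
NonEmpty = Counts > 0;
Centroids(NonEmpty, :) = bsxfun(@rdivide, Sums(NonEmpty, :), Counts(NonEmpty));
```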