Sunday, June 12, 2016

K-means for color quantization - Code not vectorized

I'm doing Andrew Ng's exercise on using k-means to reduce the number of colors in an image. My code works correctly, but I'm afraid it is a little slow because of all the for loops, so I'd like to vectorize them. However, there are some loops that I just can't seem to vectorize effectively. Please help me, thank you very much!

Also, if possible, please give some feedback on my coding style :)

Here is the link to the exercise, and here is the dataset. The correct result is given in the exercise link.

And here is my code:

function [] = KMeans()

    Image = double(imread('bird_small.tiff'));
    [rows,cols, RGB] = size(Image);
    Points = reshape(Image,rows * cols, RGB);
    K = 16;
    Centroids = zeros(K,RGB);    
    s = RandStream('mt19937ar','Seed',0);
    % Initialization:
    % Pick K random colours and make sure they are all different
    % from each other. This prevents two of the means from starting at
    % the exact same colour, so no cluster is empty in the first M-step
    % and we avoid a division by zero there.
    % However, if K = 16, for example, and there are only 15 colours in the
    % image, this while loop will never exit! This needs to be
    % addressed in the future :(
    % TODO: Vectorize this part!
    done = false;
    while ~done
        RowIndex = randperm(s,rows);
        ColIndex = randperm(s,cols);
        RowIndex = RowIndex(1:K);
        ColIndex = ColIndex(1:K);
        for i = 1 : K
            for j = 1 : RGB
                Centroids(i,j) = Image(RowIndex(i),ColIndex(i),j);
            end
        end
        % Remove duplicate colours (rows); the R, G, B order inside each row is kept
        Centroids = unique(Centroids,'rows');
        if size(Centroids,1) == K
            done = true;
        end
    end
%     imshow(imread('bird_small.tiff'))
%    
%     for i = 1 : K
%         hold on;
%         plot(RowIndex(i),ColIndex(i),'r+','MarkerSize',50)
%     end



    Tol = 0.01; % Convergence tolerance (avoid the name eps, which shadows a built-in)
    IterNum = 0;
    while true
        % E-step: Estimate membership given parameters
        % Membership: The centroid that each colour is assigned to
        % Parameters: Location of centroids
        Dist = pdist2(Points,Centroids,'euclidean');

        [~, WhichCentroid] = min(Dist,[],2);

        % M-step: Estimate parameters given membership
        % Membership: The centroid that each colour is assigned to
        % Parameters: Location of centroids
        % TODO: Vectorize this part!
        OldCentroids = Centroids;
        for i = 1 : K
            PointsInCentroid = Points(WhichCentroid == i, :);  % logical indexing, no find needed
            NumOfPoints = size(PointsInCentroid,1);
            % A cluster can become empty (for example in later iterations,
            % once the centroids have moved), so this check is still needed.
            if NumOfPoints ~= 0
                Centroids(i,:) = sum(PointsInCentroid, 1) / NumOfPoints;
            end
        end

        % Check for convergence: stop when no centroid moved more than Tol (L2 distance)
        IterNum = IterNum + 1;
        Margins = sqrt(sum((Centroids - OldCentroids).^2, 2));
        if all(Margins <= Tol)
            break;
        end

    end


    % Load the larger image
    [LargerImage,ColorMap] = imread('bird_large.tiff');
    LargerImage = double(LargerImage);
    [largeRows,largeCols,NewRGB] = size(LargerImage);  % NewRGB is 3 for an RGB image
    largeRows
    largeCols
    NewRGB
    % Replace each of the pixel with the nearest centroid    
    NewPoints = reshape(LargerImage,largeRows * largeCols, NewRGB);
    Dist = pdist2(NewPoints,Centroids,'euclidean');
    [~,WhichCentroid] = min(Dist,[],2);
    NewPoints = Centroids(WhichCentroid,:);
    LargerImage = reshape(NewPoints,largeRows,largeCols,NewRGB);

%     for i = 1 : largeRows 
%         for j = 1 : largeCols
%             Dist = pdist2(Centroids,reshape(LargerImage(i,j,:),1,RGB),'euclidean');
%             [~,WhichCentroid] = min(Dist);    
%             LargerImage(i,j,:) = Centroids(WhichCentroid,:);            
%         end
%     end

    % Display new image
    imshow(uint8(round(LargerImage)),ColorMap)
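Both distance computations above call pdist2, which comes from the Statistics and Machine Learning Toolbox, and both take square roots that cannot change which centroid is nearest. A minimal sketch of the assignment step using squared distances instead, via the expansion ||p - c||^2 = ||p||^2 + ||c||^2 - 2 p·c. The small Points and Centroids matrices are made-up example data, and bsxfun is used (an assumption about the target release) so the sketch does not rely on implicit expansion:

```matlab
% Made-up example data: 4 pixels (rows) x 3 channels, and K = 2 centroids
Points    = [0 0 0; 9 9 9; 100 100 100; 70 50 40];
Centroids = [0 0 0; 100 100 100];

% Squared Euclidean distances, N x K: ||p||^2 + ||c||^2 - 2*p*c'
% (sqrt is skipped because it does not change the argmin)
D2 = bsxfun(@plus, sum(Points.^2, 2), sum(Centroids.^2, 2)') - 2 * (Points * Centroids');

% E-step: index of the nearest centroid for every pixel
[~, WhichCentroid] = min(D2, [], 2);
```

In the function, the same two lines could stand in for the pdist2 calls, with NewPoints in place of Points for the larger image. The expansion can lose precision to cancellation when coordinates are large and nearly equal; for colour values in the 0 to 255 range that error should be negligible.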

UPDATE: Replaced

for i = 1 : K
    for j = 1 : RGB
        Centroids(i,j) = Image(RowIndex(i),ColIndex(i),j);
    end
end

with

for i = 1 : K
    Centroids(i,:) = Image(RowIndex(i),ColIndex(i),:);
end
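Besides the loop over K, the retry loop itself can go. One option, sketched here under the assumption that sampling uniformly among the image's distinct colours is acceptable, is to draw the K initial centroids directly from unique(Points,'rows'). The centroids are then distinct by construction, and the case the original comments worry about (fewer distinct colours than K) becomes an explicit error instead of an endless loop. The example Points matrix is made up:

```matlab
% Made-up example data: 6 pixels (rows) x 3 channels, with repeated colours
Points = [255 0 0; 0 255 0; 255 0 0; 0 0 255; 10 20 30; 0 255 0];
K = 3;

Colours = unique(Points, 'rows');   % one row per distinct colour
NumColours = size(Colours, 1);
if NumColours < K
    % The original retry loop would never terminate in this case
    error('Only %d distinct colours in the image, fewer than K = %d.', NumColours, K);
end

% K distinct row indices, hence K distinct colours, without any retry loop
Centroids = Colours(randperm(NumColours, K), :);
```

Unlike the original, this gives every distinct colour the same chance of being picked, however many pixels share it. With the RandStream s from the original code, randperm(s, NumColours, K) should keep the draw reproducible.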

I think this could be vectorized further using linear indexing, but for now I should focus on the while loop, since it takes most of the time. Also, when I tried @Dev-iL's suggestion and replaced

for i = 1 : K
    PointsInCentroid = Points((find(WhichCentroid == i))',:);
    NumOfPoints = size(PointsInCentroid,1);
    % Note that NumOfPoints is never equal to 0, as a result of
    % the initialization. Or .... ???????
    if NumOfPoints ~= 0
        Centroids(i,:) = sum(PointsInCentroid , 1) / NumOfPoints ;
    end
end

with

Num = size(Points,1);  % number of pixels
E = sparse(1:Num, WhichCentroid', 1, Num, K, Num);  % E(n,k) = 1 if pixel n is in cluster k
Centroids = (E * spdiags(1./sum(E,1)',0,K,K))' * Points;  % per-cluster means

the results were always worse: with K = 16, the loop version takes 2.414 s and the sparse version 2.455 s; with K = 32, the loop version takes 4.529 s and the sparse version 5.022 s. It seems vectorization does not help, but maybe there's something wrong with my code :(
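Two details of the sparse version are worth checking. Num has to be the number of pixels, size(Points,1). And if a cluster ends up empty, sum(E,1) contains a zero, so 1./sum(E,1) produces Inf and that centroid is not well defined, whereas the loop simply keeps the previous centroid. A different vectorized M-step uses accumarray to build per-cluster sums and counts and then divides only the non-empty rows. A minimal sketch with made-up example data, where cluster 3 is deliberately empty:

```matlab
% Made-up example data: 4 pixels (rows) x 3 channels, K = 3 clusters
K = 3;
Points        = [0 1 2; 2 3 4; 10 20 30; 12 22 32];
WhichCentroid = [1; 1; 2; 2];               % cluster 3 receives no pixels
OldCentroids  = [0 0 0; 9 9 9; 50 60 70];

[NumPoints, NumChannels] = size(Points);

% Per-cluster sums for all channels in one accumarray call:
% element (n, ch) of Points goes to row WhichCentroid(n), column ch.
% Points(:) is column-major, so the subscripts are built in the same order.
RowSubs = repmat(WhichCentroid, NumChannels, 1);
ColSubs = kron((1:NumChannels)', ones(NumPoints, 1));
Sums = accumarray([RowSubs ColSubs], Points(:), [K NumChannels]);

% Number of pixels in each cluster
Counts = accumarray(WhichCentroid, 1, [K 1]);

% New centroids are the cluster means; empty clusters keep their old position
Centroids = OldCentroids;
NonEmpty = Counts > 0;
Centroids(NonEmpty, :) = bsxfun(@rdivide, Sums(NonEmpty, :), Counts(NonEmpty));
```

Whether this beats the loop for K = 16 or 32 is not guaranteed. Timing the E-step and the M-step separately (for example with tic/toc or the MATLAB profiler) would show which part of each iteration actually dominates.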
