I am writing a neural network program in MATLAB, and profiling my code shows that this function is the bottleneck. The function performs a lot of sequential matrix operations. Is there any way to speed it up?
P.S. This function is called millions of times inside a for loop, and when I converted it to run on the GPU in MATLAB, it ran several times slower than on the CPU.
Thank you
function Backward_DBLSTM (tmp_input,Current_Layer,block_size,Flag,Is_Last)
%BACKWARD_DBLSTM  One backward step for direction Flag of layer
%Current_Layer of the bidirectional LSTM; accumulates weight updates.
%
%   Backward_DBLSTM(tmp_input, Current_Layer, block_size, Flag, Is_Last)
%
%   Inputs
%     tmp_input     - network input; only read when Current_Layer == 1.
%                     Assumed to be a row vector (it is transposed into a
%                     column before use) - TODO confirm.
%     Current_Layer - column index into the global per-layer cell arrays.
%     block_size    - presumably the number of memory blocks per layer.
%                     Not read any more: all sizes are taken from the
%                     operands, whose shapes must already agree with it.
%                     Kept so that callers do not need to change.
%     Flag          - row index (1 or 2) into the global cell arrays. With
%                     1 the error is propagated through W_F_Output /
%                     W_F_OutGate, with 2 through W_B_Output / W_B_OutGate.
%                     Any other value raises 'Backward_DBLSTM:badFlag'.
%     Is_Last       - when equal to 1 the error comes from Output_Delta;
%                     otherwise from OutGate_Delta of layer Current_Layer+1.
%
%   No outputs. Side effects (globals, entry {Flag,Current_Layer}):
%   OutGate_Delta, State_Error, the Rond_* arrays, the tmp_deltaW_* arrays,
%   and the deltaW_* accumulators (each incremented by its tmp_deltaW_*).
%
%   Performance (this function is called very often):
%   * Replicated operands are never built: column-times-row products are
%     computed as outer products a'*b, and row-wise scalings with bsxfun.
%     Each element is the same product of the same factors, so the
%     numerical results are unchanged.
%   * g_prime(net_Cell) and g_func(net_Cell) are evaluated once per call
%     (assumes both are side-effect-free element-wise functions - TODO
%     confirm).

global alfa
global Pre_Cell_State Cell_State Pre_Y_Cell Output_Delta net_Cell
global Y_InGate Y_FGate Y_OutGate OutGate_Delta State_Error
global W_F_OutGate W_B_OutGate W_F_Output W_B_Output
% The following weights are only referenced by the disabled error terms
% described in the "Error signal" section below.
global W_F_Cell W_B_Cell W_F_FGate W_B_FGate W_F_InGate W_B_InGate
global deltaW_Input_Cell deltaW_Input_FGate deltaW_Input_InGate deltaW_Input_OutGate
global deltaW_Cell_InGate deltaW_Cell_FGate deltaW_Cell_OutGate deltaW_Cell_Cell
global deltaW_F_Cell deltaW_B_Cell deltaW_F_InGate deltaW_B_InGate
global deltaW_F_FGate deltaW_B_FGate deltaW_F_OutGate deltaW_B_OutGate
global tmp_deltaW_Input_Cell tmp_deltaW_Input_FGate tmp_deltaW_Input_InGate tmp_deltaW_Input_OutGate
global tmp_deltaW_Cell_InGate tmp_deltaW_Cell_FGate tmp_deltaW_Cell_OutGate tmp_deltaW_Cell_Cell
global tmp_deltaW_F_Cell tmp_deltaW_B_Cell tmp_deltaW_F_InGate tmp_deltaW_B_InGate
global tmp_deltaW_F_FGate tmp_deltaW_B_FGate tmp_deltaW_F_OutGate tmp_deltaW_B_OutGate
global Pre_Rond_Cell2Cell Pre_Rond_Cell2InGate Pre_Rond_Cell2FGate
global Pre_Rond_Input2InGate Pre_Rond_Input2Cell Pre_Rond_Input2FGate
global Pre_Rond_F2InGate Pre_Rond_F2Cell Pre_Rond_F2FGate
global Pre_Rond_B2InGate Pre_Rond_B2Cell Pre_Rond_B2FGate
global Rond_Cell2Cell Rond_Cell2InGate Rond_Cell2FGate
global Rond_Input2InGate Rond_Input2Cell Rond_Input2FGate
global Rond_F2InGate Rond_F2Cell Rond_F2FGate
global Rond_B2InGate Rond_B2Cell Rond_B2FGate

L = Current_Layer;   % short alias, used only for indexing

%% ************ Error signal reaching the cell outputs of this layer
if isequal(Flag,1)
    W_out  = W_F_Output;
    W_gate = W_F_OutGate;
elseif isequal(Flag,2)
    W_out  = W_B_Output;
    W_gate = W_B_OutGate;
else
    error('Backward_DBLSTM:badFlag', 'Flag must be 1 or 2.');
end
if isequal(Is_Last,1)
    tmp = Output_Delta*W_out';
else
    % Only the output-gate path of layer L+1 is used. These additional
    % terms are disabled (kept here for reference):
    %   b = diag((Rond_F2Cell{1,L+1}*W_F_Cell{1,L+1}') + (Rond_F2Cell{2,L+1}*W_F_Cell{2,L+1}'))';
    %   c, d: same pattern with Rond_F2FGate/W_F_FGate and Rond_F2InGate/W_F_InGate
    %   (Rond_B2* / W_B_* instead when Flag == 2).
    tmp = (OutGate_Delta{1,L+1}*W_gate{1,L+1}') + (OutGate_Delta{2,L+1}*W_gate{2,L+1}');
end

%% ************ Output-gate delta and state error
y_out = Y_OutGate{Flag,L};
s     = Cell_State{Flag,L};
OutGate_Delta{Flag,L} = ( y_out.*(1-y_out) ) .* h_func(s).*tmp;
State_Error{Flag,L}   = y_out.*h_prime(s).*tmp;

%% ************ Row factors shared by all Rond_* updates (computed once)
y_f  = Y_FGate{Flag,L};
y_in = Y_InGate{Flag,L};
net  = net_Cell{Flag,L};
cell_gain = g_prime(net).*y_in;                       % Rond_*2Cell
in_gain   = g_func(net).*y_in.*(1-y_in);              % Rond_*2InGate
f_gain    = Pre_Cell_State{Flag,L}.*y_f.*(1-y_f);     % Rond_*2FGate

is_first = isequal(L,1);
y_self   = Pre_Y_Cell{Flag,L};
if is_first
    x = tmp_input;
else
    y_fw = Pre_Y_Cell{1,L-1};   % previous layer, forward direction
    y_bw = Pre_Y_Cell{2,L-1};   % previous layer, backward direction
end

%% ************ Rond (running partial derivatives)
if is_first
    Rond_Input2Cell{Flag,L}   = rond_step(Pre_Rond_Input2Cell{Flag,L},   y_f, x, cell_gain);
    Rond_Input2InGate{Flag,L} = rond_step(Pre_Rond_Input2InGate{Flag,L}, y_f, x, in_gain);
    Rond_Input2FGate{Flag,L}  = rond_step(Pre_Rond_Input2FGate{Flag,L},  y_f, x, f_gain);
else
    Rond_F2Cell{Flag,L}   = rond_step(Pre_Rond_F2Cell{Flag,L},   y_f, y_fw, cell_gain);
    Rond_B2Cell{Flag,L}   = rond_step(Pre_Rond_B2Cell{Flag,L},   y_f, y_bw, cell_gain);
    Rond_F2InGate{Flag,L} = rond_step(Pre_Rond_F2InGate{Flag,L}, y_f, y_fw, in_gain);
    Rond_B2InGate{Flag,L} = rond_step(Pre_Rond_B2InGate{Flag,L}, y_f, y_bw, in_gain);
    Rond_F2FGate{Flag,L}  = rond_step(Pre_Rond_F2FGate{Flag,L},  y_f, y_fw, f_gain);
    Rond_B2FGate{Flag,L}  = rond_step(Pre_Rond_B2FGate{Flag,L},  y_f, y_bw, f_gain);
end
Rond_Cell2Cell{Flag,L}   = rond_step(Pre_Rond_Cell2Cell{Flag,L},   y_f, y_self, cell_gain);
Rond_Cell2InGate{Flag,L} = rond_step(Pre_Rond_Cell2InGate{Flag,L}, y_f, y_self, in_gain);
Rond_Cell2FGate{Flag,L}  = rond_step(Pre_Rond_Cell2FGate{Flag,L},  y_f, y_self, f_gain);

%% ************ Per-step weight deltas (tmp_deltaW_*)
e  = State_Error{Flag,L};
od = OutGate_Delta{Flag,L};
if is_first
    tmp_deltaW_Input_OutGate{Flag,L} = alfa*(x'*od);
    tmp_deltaW_Input_Cell{Flag,L}    = alfa*bsxfun(@times, Rond_Input2Cell{Flag,L},   e);
    tmp_deltaW_Input_InGate{Flag,L}  = alfa*bsxfun(@times, Rond_Input2InGate{Flag,L}, e);
    tmp_deltaW_Input_FGate{Flag,L}   = alfa*bsxfun(@times, Rond_Input2FGate{Flag,L},  e);
else
    tmp_deltaW_F_OutGate{Flag,L} = alfa*(y_fw'*od);
    tmp_deltaW_B_OutGate{Flag,L} = alfa*(y_bw'*od);
    tmp_deltaW_F_Cell{Flag,L}    = alfa*bsxfun(@times, Rond_F2Cell{Flag,L},   e);
    tmp_deltaW_B_Cell{Flag,L}    = alfa*bsxfun(@times, Rond_B2Cell{Flag,L},   e);
    tmp_deltaW_F_InGate{Flag,L}  = alfa*bsxfun(@times, Rond_F2InGate{Flag,L}, e);
    tmp_deltaW_B_InGate{Flag,L}  = alfa*bsxfun(@times, Rond_B2InGate{Flag,L}, e);
    tmp_deltaW_F_FGate{Flag,L}   = alfa*bsxfun(@times, Rond_F2FGate{Flag,L},  e);
    tmp_deltaW_B_FGate{Flag,L}   = alfa*bsxfun(@times, Rond_B2FGate{Flag,L},  e);
end
tmp_deltaW_Cell_OutGate{Flag,L} = alfa*(y_self'*od);
tmp_deltaW_Cell_Cell{Flag,L}    = alfa*bsxfun(@times, Rond_Cell2Cell{Flag,L},   e);
tmp_deltaW_Cell_InGate{Flag,L}  = alfa*bsxfun(@times, Rond_Cell2InGate{Flag,L}, e);
tmp_deltaW_Cell_FGate{Flag,L}   = alfa*bsxfun(@times, Rond_Cell2FGate{Flag,L},  e);

%% ************ Accumulate into deltaW_*
deltaW_Cell_Cell{Flag,L}    = deltaW_Cell_Cell{Flag,L}    + tmp_deltaW_Cell_Cell{Flag,L};
deltaW_Cell_InGate{Flag,L}  = deltaW_Cell_InGate{Flag,L}  + tmp_deltaW_Cell_InGate{Flag,L};
deltaW_Cell_FGate{Flag,L}   = deltaW_Cell_FGate{Flag,L}   + tmp_deltaW_Cell_FGate{Flag,L};
deltaW_Cell_OutGate{Flag,L} = deltaW_Cell_OutGate{Flag,L} + tmp_deltaW_Cell_OutGate{Flag,L};
if is_first
    deltaW_Input_Cell{Flag,L}    = deltaW_Input_Cell{Flag,L}    + tmp_deltaW_Input_Cell{Flag,L};
    deltaW_Input_InGate{Flag,L}  = deltaW_Input_InGate{Flag,L}  + tmp_deltaW_Input_InGate{Flag,L};
    deltaW_Input_FGate{Flag,L}   = deltaW_Input_FGate{Flag,L}   + tmp_deltaW_Input_FGate{Flag,L};
    deltaW_Input_OutGate{Flag,L} = deltaW_Input_OutGate{Flag,L} + tmp_deltaW_Input_OutGate{Flag,L};
else
    deltaW_F_Cell{Flag,L}    = deltaW_F_Cell{Flag,L}    + tmp_deltaW_F_Cell{Flag,L};
    deltaW_B_Cell{Flag,L}    = deltaW_B_Cell{Flag,L}    + tmp_deltaW_B_Cell{Flag,L};
    deltaW_F_InGate{Flag,L}  = deltaW_F_InGate{Flag,L}  + tmp_deltaW_F_InGate{Flag,L};
    deltaW_B_InGate{Flag,L}  = deltaW_B_InGate{Flag,L}  + tmp_deltaW_B_InGate{Flag,L};
    deltaW_F_FGate{Flag,L}   = deltaW_F_FGate{Flag,L}   + tmp_deltaW_F_FGate{Flag,L};
    deltaW_B_FGate{Flag,L}   = deltaW_B_FGate{Flag,L}   + tmp_deltaW_B_FGate{Flag,L};
    deltaW_F_OutGate{Flag,L} = deltaW_F_OutGate{Flag,L} + tmp_deltaW_F_OutGate{Flag,L};
    deltaW_B_OutGate{Flag,L} = deltaW_B_OutGate{Flag,L} + tmp_deltaW_B_OutGate{Flag,L};
end
end

function R = rond_step(pre_rond, y_fgate, src, gain)
%ROND_STEP  Recurrence shared by every Rond_* update:
%   R(i,j) = pre_rond(i,j)*y_fgate(j) + src(i)*gain(j)
% i.e. the previous value scaled column-wise by the forget-gate row plus
% the outer product src'*gain. bsxfun avoids building replicated copies of
% y_fgate and src; the outer bsxfun(@plus,...) also lets a scalar pre_rond
% expand to the full size.
R = bsxfun(@plus, bsxfun(@times, pre_rond, y_fgate), src'*gain);
end
Aucun commentaire:
Enregistrer un commentaire