clear
close all

% Clustering of C57 (Pat.) and CAST (Mat.) traces by their pairwise TAD-TAD
% distances, visualized with t-SNE.

% Identifier prepended to every output file name.
SpecificID = '1';

%% Parameters
SaveDir = 'ClusteringFigure';   % output folder for figures and the saved workspace

DetectionEfficiency = 80;       % detection-efficiency threshold (percent) used when selecting traces
ExclStr = '000';                % traces whose detection profile contains this pattern are excluded

AbsStartCoord = 1;              % TAD window the pairwise distances are taken from
AbsEndCoord = 34;
CompStartCoord = 1;             % a TAD pair is used if at least one member lies in this window
CompEndCoord = 34;
PHRRatio = 25;                  % percentile defining the low / high PHR groups

% Encodes every parameter above so outputs of different settings do not collide.
SaveName = [SpecificID '_DE' num2str(DetectionEfficiency) '_Excl-' ExclStr ...
    '_' num2str(AbsStartCoord) '-' num2str(AbsEndCoord) ...
    '_' num2str(CompStartCoord) '-' num2str(CompEndCoord) ...
    '_PHR' num2str(PHRRatio)];

%%
% Number of TADs per trace; distance matrices are TotalNumTADs x TotalNumTADs.
TotalNumTADs = 34;

% Fail early with a clear message instead of an index error deep in the loops.
if AbsStartCoord < 1 || AbsEndCoord > TotalNumTADs || AbsStartCoord > AbsEndCoord
    error('Clustering:badAbsWindow', ...
        'Abs window [%d %d] must lie within 1..%d.', AbsStartCoord, AbsEndCoord, TotalNumTADs);
end
if CompStartCoord < 1 || CompEndCoord > TotalNumTADs || CompStartCoord > CompEndCoord
    error('Clustering:badCompWindow', ...
        'Comp window [%d %d] must lie within 1..%d.', CompStartCoord, CompEndCoord, TotalNumTADs);
end

% mkdir warns when the folder already exists, so only create it when missing.
if exist(SaveDir, 'dir') ~= 7
    mkdir(SaveDir)
end
load('ChrC57List.mat')   % expected to provide ChrC57 (used below)
load('ChrCastList.mat')  % expected to provide ChrCast (used below)

usedArea = CompStartCoord:CompEndCoord;
absArea = AbsStartCoord:AbsEndCoord;

%% C57
% Keep the C57 traces that pass the detection filters, then describe every kept
% trace by its 3D distances over a fixed list of TAD pairs.
Chr = ChrC57;
minDetectedAbs = floor((DetectionEfficiency/100)*length(absArea));
minDetectedComp = floor((DetectionEfficiency/100)*length(usedArea));
keep = false(size(Chr));
for k = 1:numel(Chr)
    % Detection-efficiency filter on both windows. floor() lets the effective
    % threshold fall slightly below DetectionEfficiency percent.
    if sum(Chr(k).r(absArea)) < minDetectedAbs || sum(Chr(k).r(usedArea)) < minDetectedComp
        continue
    end
    % Exclusion filter: render the detection profile over absArea as text
    % (one num2str per TAD, concatenated) and drop the trace if it contains
    % ExclStr. An empty ExclStr disables this filter.
    RprofileChar = arrayfun(@num2str, Chr(k).r(absArea), 'UniformOutput', false);
    RprofileChar = [RprofileChar{:}];
    if ~isempty(ExclStr) && contains(RprofileChar, ExclStr)
        continue
    end
    keep(k) = true;
end
% Logical indexing keeps Chr a struct array even when nothing passes.
Chr = Chr(keep);

% TAD pairs (i < ii) inside absArea with at least one member in usedArea,
% enumerated with i in the outer loop and ii in the inner loop; this order
% defines the columns of StackedTracesC57. The list is identical for every
% trace, so it is built once.
pairRows = zeros(1,0);
pairCols = zeros(1,0);
for i = AbsStartCoord:AbsEndCoord-1
    for ii = i+1:AbsEndCoord
        if ismember(i, usedArea) || ismember(ii, usedArea)
            pairRows(end+1) = i; %#ok<SAGROW>
            pairCols(end+1) = ii; %#ok<SAGROW>
        end
    end
end
pairLinIdx = sub2ind([TotalNumTADs TotalNumTADs], pairRows, pairCols);

% Contact-selection map for the figure: 2 on the diagonal, 1 on used pairs.
usedCoord = zeros(TotalNumTADs,TotalNumTADs);
usedCoord(1:TotalNumTADs+1:end) = 2;
usedCoord(pairLinIdx) = 1;

StackedTracesC57 = nan(numel(Chr), numel(pairLinIdx));
for k = 1:numel(Chr)
    % Pairwise Euclidean distances between the first TotalNumTADs TADs;
    % pairs involving a TAD with r ~= 1 (not detected) are set to NaN.
    xk = reshape(Chr(k).x(1:TotalNumTADs), [], 1);
    yk = reshape(Chr(k).y(1:TotalNumTADs), [], 1);
    zk = reshape(Chr(k).z(1:TotalNumTADs), [], 1);
    Mean = ((xk - xk.').^2 + (yk - yk.').^2 + (zk - zk.').^2).^0.5;
    detected = reshape(Chr(k).r(1:TotalNumTADs) == 1, [], 1);
    Mean(~(detected & detected.')) = NaN;

    % Fill the NaN entries by linear interpolation, first down the columns,
    % then along the rows.
    Mean_filtered = fillmissing(Mean,'linear');
    Mean_filtered = fillmissing(Mean_filtered,'linear',2);

    StackedTracesC57(k,:) = Mean_filtered(pairLinIdx);
end
PHRC57 = [Chr.PHR];

figure(1111)
imagesc(usedCoord)
colormap gray
axis square
saveas(gcf, fullfile(SaveDir, ['Clustering_' SaveName '_SelectedContacts.jpg']))

%% CAST
% Select CAST traces with the SAME detection-efficiency and exclusion criteria
% as the C57 selection, so both alleles are filtered identically and match the
% DE / Excl values encoded in SaveName and the figure titles.
Chr = ChrCast;
minDetectedAbs = floor((DetectionEfficiency/100)*length(absArea));
minDetectedComp = floor((DetectionEfficiency/100)*length(usedArea));
keep = false(size(Chr));
for k = 1:numel(Chr)
    % Detection-efficiency filter on both windows (floor() as for C57).
    if sum(Chr(k).r(absArea)) < minDetectedAbs || sum(Chr(k).r(usedArea)) < minDetectedComp
        continue
    end
    % Exclusion filter on the textual detection profile over absArea; an
    % empty ExclStr disables it.
    RprofileChar = arrayfun(@num2str, Chr(k).r(absArea), 'UniformOutput', false);
    RprofileChar = [RprofileChar{:}];
    if ~isempty(ExclStr) && contains(RprofileChar, ExclStr)
        continue
    end
    keep(k) = true;
end
% Logical indexing keeps Chr a struct array even when nothing passes.
Chr = Chr(keep);

% Same TAD-pair list (and column order) as used for C57, built once.
pairRows = zeros(1,0);
pairCols = zeros(1,0);
for i = AbsStartCoord:AbsEndCoord-1
    for ii = i+1:AbsEndCoord
        if ismember(i, usedArea) || ismember(ii, usedArea)
            pairRows(end+1) = i; %#ok<SAGROW>
            pairCols(end+1) = ii; %#ok<SAGROW>
        end
    end
end
pairLinIdx = sub2ind([TotalNumTADs TotalNumTADs], pairRows, pairCols);
usedCoord(pairLinIdx) = 1;

StackedTracesCAST = nan(numel(Chr), numel(pairLinIdx));
for k = 1:numel(Chr)
    % Pairwise Euclidean distances; pairs with an undetected TAD (r ~= 1) are NaN.
    xk = reshape(Chr(k).x(1:TotalNumTADs), [], 1);
    yk = reshape(Chr(k).y(1:TotalNumTADs), [], 1);
    zk = reshape(Chr(k).z(1:TotalNumTADs), [], 1);
    Mean = ((xk - xk.').^2 + (yk - yk.').^2 + (zk - zk.').^2).^0.5;
    detected = reshape(Chr(k).r(1:TotalNumTADs) == 1, [], 1);
    Mean(~(detected & detected.')) = NaN;

    % Linear interpolation of the NaN entries, down columns then along rows.
    Mean_filtered = fillmissing(Mean,'linear');
    Mean_filtered = fillmissing(Mean_filtered,'linear',2);

    StackedTracesCAST(k,:) = Mean_filtered(pairLinIdx);
end
PHRCAST = [Chr.PHR];


%% Do clustering


% True ID: 1 = CAST, 2 = C57 (CAST rows are stacked first).
StackedTracesAll = [StackedTracesCAST; StackedTracesC57];
nCAST = size(StackedTracesCAST,1);
TracesTrueID = zeros(size(StackedTracesAll,1),1);
TracesTrueID(1:nCAST) = 1;
TracesTrueID(nCAST+1:end) = 2;

% PHR ID: 1 = middle, 2 = at or below the PHRRatio-th percentile (low PHR),
% 3 = at or above the (100-PHRRatio)-th percentile (high PHR).
PHRAll = [PHRCAST'; PHRC57'];
percentileUsed = PHRRatio;
TopPHRThresh = prctile(PHRAll,100-percentileUsed);
BotPHRThresh = prctile(PHRAll,percentileUsed);

bellowThresh = find(PHRAll <= BotPHRThresh);
aboveThresh = find(PHRAll >= TopPHRThresh);

PHRID = ones(size(PHRAll));
PHRID(bellowThresh) = 2;
PHRID(aboveThresh) = 3;   % assigned last, so it wins if the thresholds coincide

% Index lists per PHR group (the names assume PHRRatio = 25).
middle50 = find(PHRID == 1);
top25 = find(PHRID == 3);   % high-PHR group
bot25 = find(PHRID == 2);   % low-PHR group

% Combined label: 1 = middle PHR, 2 = CAST outside the middle, 3 = C57 outside the middle.
PHRBackToTrueID = TracesTrueID;
PHRBackToTrueID(PHRBackToTrueID == 2) = 3;
PHRBackToTrueID(PHRBackToTrueID == 1) = 2;
PHRBackToTrueID(middle50) = 1;

% Rank-based color per trace: the trace with the k-th smallest PHR gets the
% k-th of numel(PHRAll) evenly spaced values in [min(PHRAll), max(PHRAll)].
% Assigning through the sort permutation (c(I2) = ...) maps ranks back to the
% original trace order; indexing with it (c = c(I2)) would scramble them.
[~,I2] = sort(PHRAll);
c = zeros(size(PHRAll));
c(I2) = linspace(min(PHRAll),max(PHRAll),numel(PHRAll));


%%  TSNE
% 2D t-SNE embedding of all traces. The cosine distance acts as a per-trace
% normalization. No random seed is set here, so the embedding can differ
% between runs; call rng(...) beforehand if it must be reproducible.
TSNEoutput = tsne(StackedTracesAll,'Distance','cosine');

% Embedding coordinates split by PHR group (2 = low, 3 = high) and by allele.
indLow = find(PHRID == 2);
indHigh = find(PHRID == 3);
HighTSNE = TSNEoutput(indHigh,:);
LowTSNE = TSNEoutput(indLow,:);

indCAST = find(TracesTrueID == 1);
indC57 = find(TracesTrueID == 2);
TSNECAST = TSNEoutput(indCAST,:);
TSNEC57 = TSNEoutput(indC57,:);

figure(100)

% Panel 1: colored by allele (TracesTrueID).
subplot(1,4,1)
gscatter(TSNEoutput(:,1),TSNEoutput(:,2),TracesTrueID,[1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 2: PHR groups, middle group drawn in white.
subplot(1,4,2)
gscatter(TSNEoutput(:,1),TSNEoutput(:,2),PHRID,[1 1 1; 1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 3: PHR groups, middle group drawn in grey.
subplot(1,4,3)
gscatter(TSNEoutput(:,1),TSNEoutput(:,2),PHRID,[0.5 0.5 0.5; 1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 4: rank-based PHR color c.
subplot(1,4,4)
scatter(TSNEoutput(:,1),TSNEoutput(:,2),5,c,'filled')
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

figureSaveNameAdd = ['DE' num2str(DetectionEfficiency) ' Excl-' ExclStr ' ' num2str(AbsStartCoord) '-' num2str(AbsEndCoord) ' ' num2str(CompStartCoord) '-' num2str(CompEndCoord) ' PHR' num2str(PHRRatio)];

sgtitle([figureSaveNameAdd ' - Mat. CAST: ' num2str(size(StackedTracesCAST,1)) ', Pat. C57: ' num2str(size(StackedTracesC57,1))], 'FontSize', 8)

% fullfile uses the platform's separator (a literal '\' only works on Windows).
print(gcf, fullfile(SaveDir, ['Clustering_' SaveName '.jpg']), '-djpeg', '-r1000');


%% - - - - - - - - - - 
%
% Even out numbers of traces per copy
%
% - - - - - - - - - - 


% Randomly subsample both groups to the size of the smaller one so CAST and
% C57 contribute equally. No random seed is set here, so the subset can
% differ between runs. Selected indices are kept in ascending order.
numToSelect = min([size(StackedTracesCAST,1), size(StackedTracesC57,1)]);
lengthOfMax = max([size(StackedTracesCAST,1), size(StackedTracesC57,1)]);

randomIndCAST = sort(randperm(size(StackedTracesCAST,1), numToSelect),'ascend');
randomIndC57 = sort(randperm(size(StackedTracesC57,1), numToSelect),'ascend');

% True ID: 1 = CAST (first numToSelect rows), 2 = C57.
StackedTracesAllSelect = [StackedTracesCAST(randomIndCAST,:); StackedTracesC57(randomIndC57,:)];
TracesTrueIDSelect = zeros(size(StackedTracesAllSelect,1),1);
TracesTrueIDSelect(1:numToSelect) = 1;
TracesTrueIDSelect(numToSelect+1:end) = 2;

% PHR ID on the subsample: 1 = middle, 2 = low, 3 = high (percentiles
% recomputed from the subsample itself).
PHRAllSelect = [PHRCAST(randomIndCAST)'; PHRC57(randomIndC57)'];
percentileUsed = PHRRatio;
TopPHRThreshSelect = prctile(PHRAllSelect,100-percentileUsed);
BotPHRThreshSelect = prctile(PHRAllSelect,percentileUsed);

bellowThreshSelect = find(PHRAllSelect <= BotPHRThreshSelect);
aboveThreshSelect = find(PHRAllSelect >= TopPHRThreshSelect);

PHRIDSelect = ones(size(PHRAllSelect));
PHRIDSelect(bellowThreshSelect) = 2;
PHRIDSelect(aboveThreshSelect) = 3;

% TSNE on the equalized set; the cosine distance acts as a per-trace normalization.
TSNEoutputSelect = tsne(StackedTracesAllSelect,'Distance','cosine');

indLowSelect = find(PHRIDSelect == 2);
indHighSelect = find(PHRIDSelect == 3);

HighTSNESelect = TSNEoutputSelect(indHighSelect,:);
LowTSNESelect = TSNEoutputSelect(indLowSelect,:);

% Rank-based color: the k-th smallest PHR gets the k-th evenly spaced value.
% Assigning through the sort permutation maps ranks back to trace order.
[~,I2] = sort(PHRAllSelect);
cSelect = zeros(size(PHRAllSelect));
cSelect(I2) = linspace(min(PHRAllSelect),max(PHRAllSelect),numel(PHRAllSelect));

%%
% Equalized-set coordinates split by PHR group (1 = middle, 2 = low, 3 = high)
% and by allele (1 = CAST, 2 = C57).
indMid = find(PHRIDSelect == 1);
indLow = find(PHRIDSelect == 2);
indHigh = find(PHRIDSelect == 3);
HighTSNE = TSNEoutputSelect(indHigh,:);
LowTSNE = TSNEoutputSelect(indLow,:);
MidTSNE = TSNEoutputSelect(indMid,:);

indCAST = find(TracesTrueIDSelect == 1);
indC57 = find(TracesTrueIDSelect == 2);
CASTTSNE = TSNEoutputSelect(indCAST,:);
C57TSNE = TSNEoutputSelect(indC57,:);

figure(111)
histogram(PHRAllSelect,80)

% Sliding-window count of PHR values: for each centre in xSpots, count the
% values within +/- xRange (inclusive). Both comparisons must hold (&);
% comparing the two masks with == would also count NaN values everywhere.
sortedPHRSelect = sort(PHRAllSelect,'ascend');
xSpots = (0:0.01:13)';
xRange = 0.3;
PHRHist = zeros(numel(xSpots),1);
for iSpot = 1:numel(xSpots)
    inWindow = sortedPHRSelect >= xSpots(iSpot)-xRange & sortedPHRSelect <= xSpots(iSpot)+xRange;
    PHRHist(iSpot) = sum(inWindow);
end

figure(200)

% Panel 1: colored by allele.
subplot(1,4,1)
gscatter(TSNEoutputSelect(:,1),TSNEoutputSelect(:,2),TracesTrueIDSelect,[1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 2: PHR groups, middle group drawn in white.
subplot(1,4,2)
gscatter(TSNEoutputSelect(:,1),TSNEoutputSelect(:,2),PHRIDSelect,[1 1 1; 1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 3: PHR groups, middle group drawn in grey.
subplot(1,4,3)
gscatter(TSNEoutputSelect(:,1),TSNEoutputSelect(:,2),PHRIDSelect,[0.5 0.5 0.5; 1 0 1; 0 1 1],'.',6);
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

% Panel 4: rank-based PHR color cSelect.
subplot(1,4,4)
scatter(TSNEoutputSelect(:,1),TSNEoutputSelect(:,2),5,cSelect,'filled')
legend('off');
xlabel('tsne1');
ylabel('tsne2');
axis square

figureSaveNameAdd = ['DE' num2str(DetectionEfficiency) ' Excl-' ExclStr ' ' num2str(AbsStartCoord) '-' num2str(AbsEndCoord) ' ' num2str(CompStartCoord) '-' num2str(CompEndCoord) ' PHR' num2str(PHRRatio)];

sgtitle([figureSaveNameAdd ' - Mat. CAST: ' num2str(length(randomIndCAST)) ', Pat. C57: ' num2str(length(randomIndC57))], 'FontSize', 8)

% fullfile uses the platform's separator (a literal '\' only works on Windows).
print(gcf, fullfile(SaveDir, ['Clustering_' SaveName '_Equalized.jpg']), '-djpeg', '-r1000');

% Save the whole workspace (save appends the .mat extension).
save(fullfile(SaveDir, ['Clustering_' SaveName]))
