#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 18 07:52:06 2020
@author: Enzo D’Andrea, Giuditta Davini
Workflow of Figure 3, which shows the design and operational
schematic diagram of the AI-as-a-Service (AIaaS) platform
used in this study.

AI_workflow.py
"""
#%% packages

from azure.cognitiveservices.vision.customvision.training import CustomVisionTrainingClient
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient
from azure.cognitiveservices.vision.customvision.training.models import ImageFileCreateBatch, ImageFileCreateEntry, Region
from msrest.authentication import ApiKeyCredentials
import time
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tifffile as tifi
from PIL import Image
import sklearn.metrics as metrics
#%% digital slide cut code
    
start_img = tifi.imread('input path')  # whole digital slide
output = 'output path'  # directory that receives the JPEG tiles

aproxpixel = 2200  # approximate tile edge length in pixels (arbitrary input)

# Number of grid points per axis; an axis with k points yields k - 1 tiles.
# Clamp to 2 so a slide smaller than ~2 * aproxpixel still yields one tile
# instead of silently producing none.
x = max(int(start_img.shape[1] / aproxpixel), 2)  # grid points along x (columns)
y = max(int(start_img.shape[0] / aproxpixel), 2)  # grid points along y (rows)

cutx = np.linspace(0, start_img.shape[1], num=x, endpoint=True, dtype=int)
cuty = np.linspace(0, start_img.shape[0], num=y, endpoint=True, dtype=int)

# Image cutting cycle: column strips (x) in the outer loop, rows (y) in the
# inner loop, so tiles are numbered column by column.
cutimg = [
    start_img[y0:y1, x0:x1]
    for x0, x1 in zip(cutx[:-1], cutx[1:])
    for y0, y1 in zip(cuty[:-1], cuty[1:])
]

# Save each tile as "N6 <n>.jpeg". os.path.join keeps the path portable;
# a hard-coded "\\" separator only works on Windows.
for i, img in enumerate(cutimg, start=1):
    tile_path = os.path.join(output, "N6 " + str(i) + ".jpeg")
    # cv2.imwrite expects BGR channel order, hence the RGB -> BGR conversion.
    # NOTE(review): assumes a 3-channel RGB slide - confirm for RGBA/greyscale inputs.
    if not cv2.imwrite(tile_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise OSError("could not write tile " + tile_path)
    print(i)  # progress: index of the tile just written


#%% test on created images

# Azure Custom Vision settings. They are read from environment variables so
# that keys stay out of the source code; replace the defaults with your own
# values if you prefer to hard-code them.
endpoint = os.environ.get("CUSTOM_VISION_ENDPOINT", "")  # insert here your endpoint
training_key = os.environ.get("CUSTOM_VISION_TRAINING_KEY", "")  # insert here your training key
prediction_key = os.environ.get("CUSTOM_VISION_PREDICTION_KEY", "")  # insert here your prediction key
prediction_resource_id = os.environ.get("CUSTOM_VISION_PREDICTION_RESOURCE_ID", "")  # insert here your prediction path id
project_id = os.environ.get("CUSTOM_VISION_PROJECT_ID", "")  # insert here your project ID
publish_iteration_name = os.environ.get("CUSTOM_VISION_ITERATION_NAME", "")  # insert here your iteration name

credentials = ApiKeyCredentials(in_headers={"Training-key": training_key})
trainer = CustomVisionTrainingClient(endpoint, credentials)
prediction_credentials = ApiKeyCredentials(in_headers={"Prediction-key": prediction_key})
predictor = CustomVisionPredictionClient(endpoint, prediction_credentials)

testing_set_directory = os.environ.get("TESTING_SET_DIRECTORY", "")  # test folder path

# Classify every image below testing_set_directory and collect one record per
# (image, tag) pair returned by the prediction service.
list_predictions = []
for directory, _subfolders, files in os.walk(testing_set_directory):
    folder = os.path.basename(directory)
    for f in files:
        image = os.path.splitext(f)[0]
        print(folder, image)
        with open(os.path.join(directory, f), "rb") as image_contents:
            results = predictor.classify_image(project_id, publish_iteration_name, image_contents.read())
        # extend() appends in place, avoiding a full list copy per image.
        list_predictions.extend(
            {'image': image, 'folder': folder, 'probability': p['probability'], 'tag': p['tag_name']}
            for p in results.as_dict()['predictions']
        )
        
# One row per image: a composite "folder\image" key plus one probability
# column per tag returned by the classifier.
df_predictions = (
    pd.DataFrame(list_predictions)
    .assign(pivot_index=lambda x: x.folder + '\\' + x.image)
    .pivot(index='pivot_index', columns='tag', values='probability')
    .reset_index()
)

# Recover folder and image from the composite key. `n` is passed by keyword
# because recent pandas no longer accepts it positionally in Series.str.split.
split_key = df_predictions.pivot_index.str.split('\\', n=1)
df_predictions.insert(0, 'folder', split_key.str[0])
df_predictions.insert(1, 'image', split_key.str[1])

# The ground-truth class is encoded in the image name: '+' marks positive
# tiles; '-', 'neg' or 'N' mark negative ones. The raw strings avoid
# invalid-escape warnings. The explicit string default keeps np.select's
# choices and default in one string dtype; recent NumPy refuses to promote
# str choices together with the implicit integer default 0.
df_predictions2 = (
    df_predictions
    .drop(columns='pivot_index')
    .assign(Class=lambda x: np.select(
        [x.image.str.contains(r'\+'), x.image.str.contains(r'\-|neg|N')],
        ['pos', 'neg'],
        default='unknown',
    ))
    .rename(columns={'Positive': 'AI Positive'})
    .loc[:, ['Class', 'AI Positive', 'image']]
)

#%% count how many 'AI Positive' scores exceed a user-selected threshold
def count(i, df):
    """Return the number of rows whose 'AI Positive' score is strictly above *i*.

    Parameters
    ----------
    i : float
        Threshold; rows whose score equals *i* are not counted.
    df : pandas.DataFrame
        Frame that has an 'AI Positive' column.

    Returns
    -------
    int
        Count of rows with ``df['AI Positive'] > i``.
    """
    above_threshold = df['AI Positive'] > i
    return int(above_threshold.sum())

# Probability threshold chosen by the user.
i = 0.9

b = count(i, df_predictions2)  # tiles whose 'AI Positive' score exceeds i
print(b)
#%% ROC
# Rank tiles by score and accumulate how many positives/negatives lie at or
# above each rank. Actual is derived from the frame being built (`x`);
# `z` is not bound yet at this point.
z = (
    df_predictions2
    .sort_values(['AI Positive'], ascending=False)
    .assign(p=lambda x: (x.Class == 'pos').cumsum())
    .assign(n=lambda x: (x.Class == 'neg').cumsum())
    .assign(Actual=lambda x: (x.Class == 'pos').astype(int))
)

# Manual ROC coordinates: cumulative true/false positive fractions.
pos = z.p / (z.Class == 'pos').sum()
neg = z.n / (z.Class == 'neg').sum()

actual = z.Actual
predictions = z['AI Positive']

fpr, tpr, thresh = metrics.roc_curve(actual, predictions)
auc = metrics.roc_auc_score(actual, predictions)

# Create the figure first so the labelled curve and its legend share one axes.
plt.figure(figsize=(16, 9))
plt.plot([0, 1], [0, 1], '--', color='red')  # chance diagonal (color given once)
plt.plot(neg, pos, label='AUC = %0.2f' % auc)
plt.legend(loc=0)
plt.title('ROC curve')
plt.xlabel('false positive')
plt.ylabel('true positive')
plt.show()
#%%SVG GRAPH
auc = metrics.roc_auc_score(actual, predictions)
auc_reduced = round(auc, 4)

# Small publication-sized ROC figure.
plt.figure(figsize=(3.34, 3.34))
plt.plot([0, 1], [0, 1], '--', color='red')  # chance diagonal (color given once)
plt.text(0.8, 0, s='AUC=' + str(auc_reduced), fontsize='xx-small')
plt.plot(neg, pos)
plt.title('ROC curve', fontsize='xx-small')
plt.xlabel('False positive', fontsize='xx-small')
plt.ylabel('True positive', fontsize='xx-small')
plt.xticks(fontsize='xx-small')
plt.yticks(fontsize='xx-small')
# The path has no extension, so request SVG explicitly; otherwise matplotlib
# falls back to rcParams['savefig.format'] (PNG by default).
plt.savefig('output path', format='svg', dpi=300)
#%% CONFUSION MATRIX
# Tile-level confusion matrix.
# - A tile is actually Negative when it comes from folder 'N'.
# - It is predicted Positive when its 'Positive' probability exceeds 0.05,
#   and Negative otherwise. So every tile falls in exactly one prediction
#   column, even if the two tag probabilities do not sum exactly to 1.
# - Rows are forced to [Negative, Positive] (zero-filled) so positional
#   indexing downstream still works when a class is absent.
confusion_matrix = (
    df_predictions
    .assign(ActualClass=lambda x: np.where(x.folder == 'N', 'Negative', 'Positive'))
    .assign(Predicted_Positive=lambda x: (x.Positive > 0.05).astype(int))
    .assign(Predicted_Negative=lambda x: 1 - x.Predicted_Positive)
    .groupby(['ActualClass'])[['Predicted_Positive', 'Predicted_Negative']]
    .sum()
    .reindex(['Negative', 'Positive'], fill_value=0)
)

print(confusion_matrix)
#%% metrics

def metriche(cm):
    """Compute binary-classification metrics from a 2x2 confusion matrix.

    ``cm`` is read by position. Rows are the actual class [Negative, Positive];
    columns are [Predicted_Positive, Predicted_Negative]::

        cm.iat[0, 0] -> FP    cm.iat[0, 1] -> TN
        cm.iat[1, 0] -> TP    cm.iat[1, 1] -> FN

    Returns a tuple of human-readable "name: value" strings. A metric whose
    denominator is zero (e.g. precision when nothing is predicted positive)
    is reported as nan instead of raising.
    """
    # Python ints give exact arithmetic, so the MCC denominator product
    # cannot overflow as numpy int64 values could.
    TN = int(cm.iat[0, 1])  # true negative
    TP = int(cm.iat[1, 0])  # true positive
    FP = int(cm.iat[0, 0])  # false positive
    FN = int(cm.iat[1, 1])  # false negative

    def ratio(num, den):
        # An undefined ratio (zero denominator) becomes nan rather than raising.
        return num / den if den else float('nan')

    f1 = ratio(2 * TP, 2 * TP + FP + FN)  # F1 score
    # True positive rate: proportion of actual positives correctly identified.
    trp = ratio(TP, TP + FN)
    # True negative rate: proportion of actual negatives correctly identified.
    tnr = ratio(TN, TN + FP)
    acc = ratio(TP + TN, TP + TN + FP + FN)  # accuracy
    bacc = (trp + tnr) / 2  # balanced accuracy
    # Matthews correlation coefficient (MCC)
    MCC = ratio((TP * TN) - (FP * FN),
                ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN)) ** 0.5)
    precision = ratio(TP, TP + FP)
    FM = (precision * trp) ** 0.5  # Fowlkes–Mallows index (FM)
    return ("true positive: " + str(TP),
            "true negative: " + str(TN),
            "false positive: " + str(FP),
            "false negative: " + str(FN),
            "f1: " + str(f1),
            "true positive rate: " + str(trp),
            "true negative rate: " + str(tnr),
            "accuracy: " + str(acc),
            "balance accuracy: " + str(bacc),
            "Matthews correlation coefficient: " + str(MCC),
            "Fowlkes–Mallows index: " + str(FM)
            )


met = metriche(confusion_matrix)  # tuple of "metric: value" strings
print(met)
