Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Feat/temperature scaling confidence calibration #1434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
David-Magdy wants to merge 2 commits into JaidedAI:master
base: master
Choose a base branch
Loading
from David-Magdy:feat/temperature-scaling-confidence
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions easyocr/easyocr.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
workers = 0, allowlist = None, blocklist = None, detail = 1,\
rotation_info = None,paragraph = False,\
contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'):
y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard',temperature=1.0):

if reformat:
img, img_cv_grey = reformat_input(img_cv_grey)
Expand Down Expand Up @@ -383,15 +383,15 @@ def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
workers, self.device,temperature)
result += result0
for bbox in free_list:
h_list = []
f_list = [bbox]
image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH)
result0 = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
workers, self.device)
workers, self.device,temperature)
result += result0
# default mode will try to process multiple boxes at the same time
else:
Expand Down
17 changes: 12 additions & 5 deletions easyocr/recognition.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __call__(self, batch):
return image_tensors

def recognizer_predict(model, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu'):
ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu', temperature=1.0):
model.eval()
result = []
with torch.no_grad():
Expand All @@ -110,6 +110,9 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\

preds = model(image, text_for_pred)

# Apply Temperature Scaling
preds = preds / temperature

# Select max probabilty (greedy decoding) then decode index to character
preds_size = torch.IntTensor([preds.size(1)] * batch_size)

Expand Down Expand Up @@ -145,7 +148,11 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\
preds_max_prob.append(np.array([0]))

for pred, pred_max_prob in zip(preds_str, preds_max_prob):
confidence_score = custom_mean(pred_max_prob)
# Replaced custom_mean with a standard average of probabilities
if len(pred_max_prob) == 0:
confidence_score = 0.0
else:
confidence_score = pred_max_prob.mean()
result.append([pred, confidence_score])

return result
Expand Down Expand Up @@ -185,7 +192,7 @@ def get_recognizer(recog_network, network_params, character,\

def get_text(character, imgH, imgW, recognizer, converter, image_list,\
ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\
adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'):
adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu', temperature=1.0):
batch_max_length = int(imgW/10)

char_group_idx = {}
Expand All @@ -204,7 +211,7 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\

# predict first round
result1 = recognizer_predict(recognizer, converter, test_loader,batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)
ignore_idx, char_group_idx, decoder, beamWidth, device=device, temperature=temperature)

# predict second round
low_confident_idx = [i for i,item in enumerate(result1) if (item[1] < contrast_ths)]
Expand All @@ -216,7 +223,7 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\
test_data, batch_size=batch_size, shuffle=False,
num_workers=int(workers), collate_fn=AlignCollate_contrast, pin_memory=True)
result2 = recognizer_predict(recognizer, converter, test_loader, batch_max_length,\
ignore_idx, char_group_idx, decoder, beamWidth, device = device)
ignore_idx, char_group_idx, decoder, beamWidth, device=device, temperature=temperature)

result = []
for i, zipped in enumerate(zip(coord, result1)):
Expand Down

AltStyle によって変換されたページ (->オリジナル) /