blob: 850195714782dbcabde95b6289955d279452a007 [file] [log] [blame]
giolekvab0b7f002020-04-16 16:05:40 +04001import os
2import sys
giolekva550add72020-04-11 20:01:40 +04003
giolekvab0b7f002020-04-16 16:05:40 +04004from facenet_pytorch import MTCNN, InceptionResnetV1
5from PIL import Image
giolekva550add72020-04-11 20:01:40 +04006
giolekva550add72020-04-11 20:01:40 +04007
def detect(input_dir, output_dir):
    """Print an embedding for every face detected in each image of a directory.

    Each entry of ``input_dir`` is opened with PIL and passed through an
    MTCNN detector created with ``keep_all=True``. Every face crop it returns
    is fed to an ``InceptionResnetV1`` model (``vggface2`` weights, eval
    mode), and the image file name and the model output are printed.

    Args:
        input_dir: Directory whose entries are opened as images.
        output_dir: Currently unused by this function. Presumably intended as
            the destination for cropped faces (see the TODO sketch below).
    """
    mtcnn = MTCNN(keep_all=True)
    resnet = InceptionResnetV1(pretrained='vggface2').eval()
    for name in os.listdir(input_dir):
        with Image.open(os.path.join(input_dir, name)) as img:
            print(img.filename)
            faces = mtcnn(img)
            if faces is None:
                # MTCNN yields None when it finds no face; iterating it
                # would raise TypeError, so skip this image instead.
                continue
            for face in faces:
                # Iterating the stacked crops gives one unbatched face at a
                # time; the embedding model expects a leading batch axis.
                print(resnet(face.unsqueeze(0)))

            # TODO: save the cropped faces into output_dir, e.g.:
            # boxes, _ = mtcnn.detect(img)
            # for i, box in enumerate(boxes):
            #     cropped = img.crop(box)
            #     cropped.save(output_dir + "/" + str(i) + "_" + name)
26
27
def classify(input_dir, output_dir):
    """Print the embedding length for the face detected in each image.

    Each entry of ``input_dir`` is opened with PIL and passed through an
    MTCNN detector created with default settings. The detected face crop is
    fed to an ``InceptionResnetV1`` model (``vggface2`` weights, eval mode),
    and the image file name and the length of the resulting embedding vector
    are printed. Images in which no face is detected only have their file
    name printed.

    Args:
        input_dir: Directory whose entries are opened as images.
        output_dir: Currently unused by this function.
    """
    mtcnn = MTCNN()
    resnet = InceptionResnetV1(pretrained='vggface2').eval()
    for name in os.listdir(input_dir):
        with Image.open(os.path.join(input_dir, name)) as img:
            print(img.filename)
            face = mtcnn(img)
            if face is None:
                # No face detected: there is nothing to embed for this image.
                continue
            # The detector output is a single unbatched face crop here, so
            # add the batch axis the embedding model requires.
            embedding = resnet(face.unsqueeze(0))
            print(len(embedding[0]))
36
37
def main():
    """Dispatch to ``detect`` or ``classify`` based on command-line arguments.

    Expects ``sys.argv`` to hold: mode, input directory, output directory.
    The mode ``"detect"`` runs ``detect``; any other mode runs ``classify``.
    Arguments beyond the third are ignored.

    Raises:
        SystemExit: With status 2 (after printing a usage line to stderr)
            when fewer than three arguments are supplied.
    """
    if len(sys.argv) < 4:
        # Fail with a readable usage message instead of a bare IndexError.
        prog = sys.argv[0] if sys.argv else "script"
        print(
            "usage: " + prog + " <detect|classify> <input_dir> <output_dir>",
            file=sys.stderr,
        )
        sys.exit(2)

    mode, input_dir, output_dir = sys.argv[1:4]
    if mode == "detect":
        detect(input_dir, output_dir)
    else:
        classify(input_dir, output_dir)
43
44
# Run the command-line dispatcher only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()