| giolekva | b0b7f00 | 2020-04-16 16:05:40 +0400 | [diff] [blame^] | 1 | import os |
| 2 | import sys |
| giolekva | 550add7 | 2020-04-11 20:01:40 +0400 | [diff] [blame] | 3 | |
| giolekva | b0b7f00 | 2020-04-16 16:05:40 +0400 | [diff] [blame^] | 4 | from facenet_pytorch import MTCNN, InceptionResnetV1 |
| 5 | from PIL import Image |
| giolekva | 550add7 | 2020-04-11 20:01:40 +0400 | [diff] [blame] | 6 | |
| giolekva | 550add7 | 2020-04-11 20:01:40 +0400 | [diff] [blame] | 7 | |
def detect(input_dir, output_dir):
    """Print the face embeddings of every face found in each image of a directory.

    Entries of ``input_dir`` are visited in sorted order. Each regular file
    that Pillow can open has its path printed. All faces in it are located
    with MTCNN (``keep_all=True``), and the InceptionResnetV1 (VGGFace2)
    embedding of each face is printed. Non-image files and images with no
    detected face are reported on stderr and skipped.

    Args:
        input_dir: Directory containing the images to process.
        output_dir: Currently unused; kept so the CLI signature is unchanged.
    """
    mtcnn = MTCNN(keep_all=True)
    resnet = InceptionResnetV1(pretrained='vggface2').eval()
    for name in sorted(os.listdir(input_dir)):
        path = os.path.join(input_dir, name)
        if not os.path.isfile(path):
            continue  # skip sub-directories and other non-regular entries
        try:
            img = Image.open(path)
        except OSError as err:
            # A stray non-image file (e.g. .DS_Store) should not abort the run.
            print("skipping", path, err, file=sys.stderr)
            continue
        with img:
            print(path)
            # MTCNN expects 3-channel RGB; RGBA/greyscale inputs would fail.
            faces = mtcnn(img.convert("RGB"))
        if faces is None:
            # MTCNN returns None instead of an empty tensor when nothing is found.
            print("no face found in", path, file=sys.stderr)
            continue
        # With keep_all=True, faces is an (n, 3, H, W) batch. The network only
        # accepts 4-D batches, so embed all faces at once rather than feeding
        # it 3-D crops one by one.
        for embedding in resnet(faces):
            print(embedding)
| 26 | |
| 27 | |
def classify(input_dir, output_dir):
    """Embed one face per image in a directory and print the embedding length.

    Entries of ``input_dir`` are visited in sorted order. Each regular file
    that Pillow can open has its path printed. MTCNN with default settings
    (a single face crop) extracts a face, and the length of its
    InceptionResnetV1 (VGGFace2) embedding is printed. Non-image files and
    images with no detected face are reported on stderr and skipped.

    Args:
        input_dir: Directory containing the images to process.
        output_dir: Currently unused; kept so the CLI signature is unchanged.
    """
    mtcnn = MTCNN()
    resnet = InceptionResnetV1(pretrained='vggface2').eval()
    for name in sorted(os.listdir(input_dir)):
        path = os.path.join(input_dir, name)
        if not os.path.isfile(path):
            continue  # skip sub-directories and other non-regular entries
        try:
            img = Image.open(path)
        except OSError as err:
            # A stray non-image file (e.g. .DS_Store) should not abort the run.
            print("skipping", path, err, file=sys.stderr)
            continue
        with img:
            print(path)
            # MTCNN expects 3-channel RGB; RGBA/greyscale inputs would fail.
            face = mtcnn(img.convert("RGB"))
        if face is None:
            # MTCNN returns None when it finds no face in the image.
            print("no face found in", path, file=sys.stderr)
            continue
        # face is a single (3, H, W) crop; add the batch dimension the
        # network requires before embedding it.
        embedding = resnet(face.unsqueeze(0))
        print(len(embedding[0]))
| 36 | |
| 37 | |
def main():
    """Command-line entry point: ``<detect|classify> INPUT_DIR OUTPUT_DIR``.

    ``detect`` runs ``detect``. Any other command falls back to ``classify``,
    as the original dispatch did. Extra trailing arguments are ignored. When
    fewer than three arguments are given, a usage message is printed to
    stderr and the process exits with status 2.
    """
    if len(sys.argv) < 4:
        print("usage: <detect|classify> INPUT_DIR OUTPUT_DIR", file=sys.stderr)
        sys.exit(2)
    command, input_dir, output_dir = sys.argv[1:4]
    if command == "detect":
        detect(input_dir, output_dir)
    else:
        classify(input_dir, output_dir)
| 43 | |
| 44 | |
# Script entry point: run the command-line dispatch in main() only when this
# file is executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()