blob: bffa79dc2ff79597782c7e48fad2a7c52c67d23b [file] [log] [blame]
giolekvafb52e0d2020-04-23 22:52:13 +04001import sys
2import json
3import urllib.parse
4import urllib.request
5import os
6
7from facenet_pytorch import MTCNN, InceptionResnetV1
8from PIL import Image
9
10
def detect_faces(img_file):
    """Detect faces in an image and return their bounding boxes.

    Args:
        img_file: Anything accepted by ``PIL.Image.open`` (path or file object).

    Returns:
        A list of ``(x1, y1, x2, y2)`` tuples, one per box reported by
        ``MTCNN.detect``. An empty list when no face is detected.
    """
    mtcnn = MTCNN(keep_all=True)
    with Image.open(img_file) as img:
        # The detector expects 3-channel input; normalise RGBA / palette /
        # greyscale images so they do not fail inside the network.
        boxes = mtcnn.detect(img.convert("RGB"))[0]
    # MTCNN.detect reports "no faces" as None rather than an empty array;
    # iterating None would raise TypeError.
    if boxes is None:
        return []
    return [(box[0], box[1], box[2], box[3]) for box in boxes]
18
19
def fetch_file_for_image(gql_endpoint, object_storage_endpoint, id):
    """Download the stored file backing an image and return its local path.

    Looks up the image's ``objectPath`` through the GraphQL endpoint, then
    retrieves ``<object_storage_endpoint>/<objectPath>`` with
    ``urllib.request.urlretrieve``.

    Args:
        gql_endpoint: URL the query is POSTed to as a form-encoded ``query``.
        object_storage_endpoint: Base URL that the object path is appended to.
        id: Identifier of the image to look up.

    Returns:
        Local filesystem path of the retrieved file. The file is not deleted
        here; the caller is responsible for cleaning it up.
    """
    # json.dumps produces a quoted, escaped string literal that is also valid
    # GraphQL, so an id containing quotes/backslashes cannot break the query.
    query = "{ getImage(id: " + json.dumps(str(id)) + ") { objectPath } }"
    encoded_data = urllib.parse.urlencode({"query": query}).encode('UTF-8')
    req = urllib.request.Request(gql_endpoint, encoded_data, method="POST")
    # Close the response once read instead of leaking the connection.
    with urllib.request.urlopen(req) as resp:
        body = resp.read()
    # NOTE(review): expects getImage at the top level of the response (no
    # "data" envelope), mirroring the original parsing — confirm against server.
    object_path = json.loads(body)["getImage"]["objectPath"]
    local_path = urllib.request.urlretrieve(
        object_storage_endpoint + "/" + object_path)[0]
    return local_path
29
30
def format_img_segment(id, box):
    """Render one image segment as a GraphQL input-object literal.

    Args:
        id: Identifier of the source image the segment belongs to.
        box: Indexable of at least four coordinates
            ``(upperLeftX, upperLeftY, lowerRightX, lowerRightY)``.

    Returns:
        A string such as
        ``{upperLeftX: 1, upperLeftY: 2, lowerRightX: 3, lowerRightY: 4,
        sourceImage: {id: "0x1"}}``.
    """
    # json.dumps quotes and escapes the id so that quotes or backslashes in it
    # cannot terminate the GraphQL string literal early.
    return ("{{upperLeftX: {f[0]}, upperLeftY: {f[1]}, lowerRightX: {f[2]}, " +
            "lowerRightY: {f[3]}, sourceImage: {{id: {id}}}}}").format(
        f=box,
        id=json.dumps(str(id)))
36
37
def upload_face_segments(gql_endpoint, id, faces):
    """Create one image segment per face box via an addImageSegment mutation.

    Args:
        gql_endpoint: URL the mutation is POSTed to as a form-encoded ``query``.
        id: Identifier of the source image every segment refers to.
        faces: Iterable of boxes accepted by ``format_img_segment``.

    The raw response body is printed. No request is made when ``faces`` is
    empty.
    """
    segments = [format_img_segment(id, f) for f in faces]
    if not segments:
        # Nothing detected: skip the round-trip instead of sending an empty
        # input list to the server.
        return
    data = {"query": "mutation {{ addImageSegment(input: [{segments}]) {{ imagesegment {{ id }} }} }}".format(
        segments=", ".join(segments))}
    encoded_data = urllib.parse.urlencode(data).encode('UTF-8')
    req = urllib.request.Request(gql_endpoint, encoded_data, method="POST")
    # Close the response after reading instead of leaking the connection.
    with urllib.request.urlopen(req) as resp:
        print(resp.read())
46
47
def main():
    """Fetch one image, detect faces in it and upload them as segments.

    Command line: GQL_ENDPOINT OBJECT_STORAGE_ENDPOINT IMAGE_ID
    """
    if len(sys.argv) < 4:
        sys.exit("usage: {} <gql_endpoint> <object_storage_endpoint> <image_id>".format(
            sys.argv[0]))
    gql_endpoint, object_storage_endpoint, image_id = sys.argv[1:4]
    local_path = fetch_file_for_image(gql_endpoint, object_storage_endpoint, image_id)
    try:
        faces = detect_faces(local_path)
    finally:
        # Delete the downloaded copy even when face detection fails.
        os.remove(local_path)
    upload_face_segments(gql_endpoint, image_id, faces)
53
54
# Script entry point: runs the fetch -> detect -> upload pipeline in main().
if __name__ == "__main__":
    main()