vertalius commited on
Commit
f3457a8
·
verified ·
1 Parent(s): 159d46f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -41
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import tempfile
5
  from typing import Optional, Tuple
6
  from datetime import datetime
 
 
7
 
8
  from pose_detector import PoseDetector
9
  from skeleton_generator import SkeletonGenerator
@@ -68,14 +70,13 @@ def init_page():
68
  key='export_format'
69
  )
70
 
71
- # Инициализируем значение manual_correction только при первом запуске
72
  if "manual_correction" not in st.session_state:
73
  st.session_state.manual_correction = st.sidebar.checkbox("Enable Manual Corrections")
74
  else:
75
  st.session_state.manual_correction = st.sidebar.checkbox("Enable Manual Corrections", value=st.session_state.manual_correction)
76
 
77
  if st.session_state.manual_correction:
78
- st.sidebar.info("Click on landmarks in the preview to adjust their positions")
79
 
80
  st.title("Pose Detection & Animation Generator")
81
  return confidence_threshold
@@ -118,7 +119,7 @@ def main():
118
  with col1:
119
  st.subheader("Original")
120
  with col2:
121
- st.subheader("Processed")
122
 
123
  try:
124
  if file_type == 'image' and not is_gif:
@@ -140,52 +141,94 @@ def main():
140
  db.close()
141
 
142
  def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
143
- """Handle image file upload processing."""
144
  pose_detector, skeleton_generator, animation_exporter = components
145
 
146
- file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
147
- image = cv2.imdecode(file_bytes, 1)
 
 
 
 
 
148
 
149
  with col1:
150
  st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)
 
 
 
 
 
 
 
 
 
151
 
152
- processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
153
-
154
- if not skeleton_data:
155
- raise ValueError("No pose detected in the image")
156
-
157
  save_pose_data(db, processed_file.id, skeleton_data)
158
  animation_data_binary = animation_exporter.export_pose(skeleton_data)
159
  save_animation_data(db, processed_file.id, skeleton_data)
160
 
161
  with col2:
 
162
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
163
- canvas_container = st.empty()
164
- canvas_container.image(processed_rgb, use_column_width=True)
165
-
166
- if st.session_state.get('manual_correction', False):
167
- # Инициализируем текущие координаты, если ещё не заданы
168
- if "current_landmarks" not in st.session_state or st.session_state.current_landmarks is None:
169
- st.session_state.current_landmarks = skeleton_data.copy()
170
- joints = st.session_state.current_landmarks
171
-
172
- selected_joint = st.selectbox("Select Joint to Adjust", list(joints.keys()), key="selected_joint")
173
-
174
- # Используем уникальные ключи для слайдеров
175
- x_pos = st.slider("X Position", 0.0, 1.0, float(joints[selected_joint]['position'][0]), 0.01, key=f"x_{selected_joint}")
176
- y_pos = st.slider("Y Position", 0.0, 1.0, float(joints[selected_joint]['position'][1]), 0.01, key=f"y_{selected_joint}")
177
-
178
- apply_clicked = st.button("Apply Changes", key="apply_changes")
179
- save_clicked = st.button("Save Corrections", key="save_corrections")
180
-
181
- if apply_clicked:
182
- joints[selected_joint]['position'] = [x_pos, y_pos]
183
- st.session_state.current_landmarks = joints # обновляем значение в session_state
184
- processed_image = pose_detector.draw_corrected_pose(image, joints)
185
- processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
186
- canvas_container.image(processed_rgb, use_column_width=True)
187
 
188
- if save_clicked:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  save_corrected_pose(db, processed_file.id, joints)
190
  st.success("Corrections saved successfully!")
191
 
@@ -252,11 +295,13 @@ def save_corrected_pose(db, file_id: int, joints: dict):
252
  def show_instructions():
253
  with st.expander("Instructions"):
254
  st.markdown("""
255
- 1. Upload an image/video using the file uploader
256
- 2. Wait for processing to complete
257
- 3. Preview results in the right column
258
- 4. Download animation data
259
-
 
 
260
  Supported formats:
261
  - Images: JPG, PNG
262
  - Videos: MP4, GIF
 
4
  import tempfile
5
  from typing import Optional, Tuple
6
  from datetime import datetime
7
+ from PIL import Image
8
+ from streamlit_drawable_canvas import st_canvas
9
 
10
  from pose_detector import PoseDetector
11
  from skeleton_generator import SkeletonGenerator
 
70
  key='export_format'
71
  )
72
 
 
73
  if "manual_correction" not in st.session_state:
74
  st.session_state.manual_correction = st.sidebar.checkbox("Enable Manual Corrections")
75
  else:
76
  st.session_state.manual_correction = st.sidebar.checkbox("Enable Manual Corrections", value=st.session_state.manual_correction)
77
 
78
  if st.session_state.manual_correction:
79
+ st.sidebar.info("Click on the preview canvas to select a joint")
80
 
81
  st.title("Pose Detection & Animation Generator")
82
  return confidence_threshold
 
119
  with col1:
120
  st.subheader("Original")
121
  with col2:
122
+ st.subheader("Processed (click to select joint)")
123
 
124
  try:
125
  if file_type == 'image' and not is_gif:
 
141
  db.close()
142
 
143
  def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
144
+ """Handle image file upload processing with persistent state and interactive joint selection."""
145
  pose_detector, skeleton_generator, animation_exporter = components
146
 
147
+ # Сохраняем исходное изображение в session_state
148
+ if "uploaded_image" not in st.session_state:
149
+ file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
150
+ image = cv2.imdecode(file_bytes, 1)
151
+ st.session_state.uploaded_image = image
152
+ else:
153
+ image = st.session_state.uploaded_image
154
 
155
  with col1:
156
  st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)
157
+
158
+ # Обработка изображения для получения позы и скелета
159
+ if "original_skeleton_data" not in st.session_state:
160
+ processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
161
+ st.session_state.original_skeleton_data = skeleton_data
162
+ st.session_state.processed_image = processed_image
163
+ else:
164
+ skeleton_data = st.session_state.original_skeleton_data
165
+ processed_image = st.session_state.processed_image
166
 
167
+ # Сохраняем данные поз в БД
 
 
 
 
168
  save_pose_data(db, processed_file.id, skeleton_data)
169
  animation_data_binary = animation_exporter.export_pose(skeleton_data)
170
  save_animation_data(db, processed_file.id, skeleton_data)
171
 
172
  with col2:
173
+ height, width = processed_image.shape[:2]
174
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
175
+ pil_image = Image.fromarray(processed_rgb)
176
+
177
+ # Отображаем изображение через st_canvas, чтобы можно было кликать по нему.
178
+ canvas_result = st_canvas(
179
+ fill_color="rgba(0, 0, 0, 0)", # прозрачный фон для рисования
180
+ stroke_width=5,
181
+ stroke_color="#FF0000",
182
+ background_image=pil_image,
183
+ update_streamlit=True,
184
+ height=height,
185
+ width=width,
186
+ drawing_mode="point", # режим для регистрации кликов
187
+ key="canvas"
188
+ )
189
+
190
+ # Если пользователь сделал клик, в canvas_result.json_data появятся объекты.
191
+ if st.session_state.get('manual_correction', False) and canvas_result.json_data is not None:
192
+ objects = canvas_result.json_data.get("objects", [])
193
+ if objects:
194
+ # Берём последний добавленный объект как клик.
195
+ last_obj = objects[-1]
196
+ click_x = last_obj.get("left")
197
+ click_y = last_obj.get("top")
 
198
 
199
+ # Находим ближайший joint (учитывая, что координаты суставов нормализованы)
200
+ min_dist = float("inf")
201
+ selected_joint = None
202
+ for joint_name, data in skeleton_data.items():
203
+ joint_px = data['position'][0] * width
204
+ joint_py = data['position'][1] * height
205
+ dist = ((joint_px - click_x)**2 + (joint_py - click_y)**2)**0.5
206
+ if dist < min_dist:
207
+ min_dist = dist
208
+ selected_joint = joint_name
209
+ threshold = 20 # порог в пикселях для выбора сустава
210
+ if min_dist < threshold:
211
+ st.session_state.active_joint = selected_joint
212
+ st.write(f"Selected joint: **{selected_joint}**")
213
+ else:
214
+ st.write("Click closer to a joint to select it.")
215
+
216
+ # Если активный joint выбран, показываем слайдеры для его корректировки.
217
+ if st.session_state.get("active_joint"):
218
+ active_joint = st.session_state.active_joint
219
+ joints = st.session_state.get("current_landmarks", skeleton_data.copy())
220
+ st.write(f"Active joint for adjustment: **{active_joint}**")
221
+ x_pos = st.slider("Adjust X", 0.0, 1.0, float(joints[active_joint]['position'][0]), 0.01, key=f"adj_x_{active_joint}")
222
+ y_pos = st.slider("Adjust Y", 0.0, 1.0, float(joints[active_joint]['position'][1]), 0.01, key=f"adj_y_{active_joint}")
223
+ if st.button("Apply Adjustment", key=f"apply_adj_{active_joint}"):
224
+ joints[active_joint]['position'] = [x_pos, y_pos]
225
+ st.session_state.current_landmarks = joints
226
+ corrected_image = pose_detector.draw_corrected_pose(image, joints)
227
+ st.session_state.processed_image = corrected_image
228
+ corrected_rgb = cv2.cvtColor(corrected_image, cv2.COLOR_BGR2RGB)
229
+ st.image(corrected_rgb, use_column_width=True)
230
+ st.write(f"Updated {active_joint}: {joints[active_joint]['position']}")
231
+ if st.button("Save Corrections", key=f"save_adj_{active_joint}"):
232
  save_corrected_pose(db, processed_file.id, joints)
233
  st.success("Corrections saved successfully!")
234
 
 
295
  def show_instructions():
296
  with st.expander("Instructions"):
297
  st.markdown("""
298
+ 1. Upload an image/video using the file uploader.
299
+ 2. Wait for processing to complete.
300
+ 3. In the **Processed** panel, click on the preview canvas near a joint to select it.
301
+ 4. Adjust the selected joint using the sliders and click **Apply Adjustment**.
302
+ 5. Click **Save Corrections** to store the changes.
303
+ 6. Download animation data.
304
+
305
  Supported formats:
306
  - Images: JPG, PNG
307
  - Videos: MP4, GIF