Julien Blanchon committed on
Commit
63a6b63
·
1 Parent(s): 1ce9d31
Files changed (2) hide show
  1. app.py +435 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env -S uv run --script
2
+ # /// script
3
+ # requires-python = ">=3.11"
4
+ # dependencies = [
5
+ # "requests<3",
6
+ # "pillow",
7
+ # "opencv-python",
8
+ # "pyboy",
9
+ # "huggingface-hub",
10
+ # "gradio",
11
+ # "numpy",
12
+ # "nitrogen @ git+https://github.com/MineDojo/NitroGen.git@main",
13
+ # ]
14
+ # [tool.uv]
15
+ # exclude-newer = "2025-12-22T00:00:00Z"
16
+ # ///
17
+ """
18
+ Unified Gradio app for NitroGen Pokemon Red player with real-time streaming
19
+ Combines model inference and PyBoy gameplay in a single interface
20
+ """
21
+ import gradio as gr
22
+ from pathlib import Path
23
+ import cv2
24
+ import numpy as np
25
+ from PIL import Image
26
+ from pyboy import PyBoy
27
+ from pyboy.utils import WindowEvent
28
+ import time
29
+ import tempfile
30
+ import requests
31
+ from huggingface_hub import HfFileSystem
32
+
33
+ from nitrogen.inference_session import InferenceSession
34
+ from nitrogen.shared import PATH_REPO, BUTTON_ACTION_TOKENS
35
+
36
# Where the ROM is fetched from, and the optional emulator save state that is
# loaded at start-up when the file exists.
ROM_URL = (
    "https://github.com/hxh-robb/pokemon-roms/raw/refs/heads/master/ROM/"
    "Pokemon%20-%20Red%20Version%20(USA,%20Europe).gb"
)
STATE_PATH = "./init.state"

# Game Boy button name -> suffix shared by the matching PRESS_* / RELEASE_*
# members of WindowEvent. Both lookup tables below are derived from this so
# they always cover the same buttons in the same order.
_GB_EVENT_SUFFIXES = {
    "A": "BUTTON_A",
    "B": "BUTTON_B",
    "START": "BUTTON_START",
    "SELECT": "BUTTON_SELECT",
    "UP": "ARROW_UP",
    "DOWN": "ARROW_DOWN",
    "LEFT": "ARROW_LEFT",
    "RIGHT": "ARROW_RIGHT",
}

# Game Boy button mapping: button name -> "press" event
GB_BUTTONS = {
    button: getattr(WindowEvent, f"PRESS_{suffix}")
    for button, suffix in _GB_EVENT_SUFFIXES.items()
}

# Button name -> "release" event
GB_BUTTONS_RELEASE = {
    button: getattr(WindowEvent, f"RELEASE_{suffix}")
    for button, suffix in _GB_EVENT_SUFFIXES.items()
}
61
+
62
def preprocess_img(frame):
    """Turn an emulator frame into the 256x256 RGB PIL image fed to the model.

    Args:
        frame: Either a PIL image or an array exposing ``.shape``.
            Two-dimensional input is treated as grayscale and four-channel
            input as RGBA; anything else is passed to the resize unchanged.

    Returns:
        A ``PIL.Image.Image`` built from the frame resized to 256x256.
    """
    pixels = np.array(frame) if isinstance(frame, Image.Image) else frame

    dims = pixels.shape
    if len(dims) == 2:
        # Single-channel frame: expand to three channels.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    elif dims[2] == 4:
        # Drop the alpha channel.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)

    model_size = (256, 256)
    shrunk = cv2.resize(pixels, model_size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(shrunk)
74
+
75
def gamepad_to_gameboy_buttons(pred, button_threshold=0.5, joystick_threshold=0.3):
    """Convert the model's gamepad prediction to Game Boy button presses.

    Args:
        pred: Mapping with ``"j_left"``, ``"j_right"`` and ``"buttons"``
            entries. Only the first element of ``"buttons"`` and of
            ``"j_left"`` is read; ``"j_right"`` is looked up but not used.
        button_threshold: A button counts as pressed when its predicted value
            is strictly greater than this.
        joystick_threshold: Minimum left-stick deflection (strict) for the
            stick to stand in for the D-pad.

    Returns:
        A list of Game Boy button names (``"UP"``, ``"DOWN"``, ``"LEFT"``,
        ``"RIGHT"``, ``"A"``, ``"B"``, ``"START"``, ``"SELECT"``), each at most
        once, in the order they were detected. The list is empty when the
        prediction carries no buttons or too few button values.

    Raises:
        KeyError: If ``pred`` lacks one of the three expected entries.
    """
    j_left, _j_right, buttons = pred["j_left"], pred["j_right"], pred["buttons"]
    pressed_buttons = []

    # Explicit len() check rather than truthiness: the containers may be
    # array-like (assumption), where bool() would be ambiguous.
    if len(buttons) == 0:
        return pressed_buttons

    button_vals = buttons[0]

    # The fixed indices used below go up to 20 (WEST). Require that many
    # values in addition to the token-list length so a short prediction is
    # ignored instead of raising IndexError.
    highest_index_used = 20
    required = max(len(BUTTON_ACTION_TOKENS), highest_index_used + 1)
    if len(button_vals) < required:
        return pressed_buttons

    def is_pressed(index):
        return button_vals[index] > button_threshold

    # D-Pad mapping (indices 1-4)
    dpad = ((1, "DOWN"), (2, "LEFT"), (3, "RIGHT"), (4, "UP"))
    pressed_buttons.extend(name for index, name in dpad if is_pressed(index))

    # Joystick fallback if no D-pad pressed: follow the dominant axis of the
    # left stick, and only when it is deflected past the threshold.
    if not pressed_buttons and len(j_left) > 0:
        xl, yl = j_left[0]
        if abs(xl) > abs(yl):
            if xl > joystick_threshold:
                pressed_buttons.append("RIGHT")
            elif xl < -joystick_threshold:
                pressed_buttons.append("LEFT")
        else:
            if yl > joystick_threshold:
                pressed_buttons.append("DOWN")
            elif yl < -joystick_threshold:
                pressed_buttons.append("UP")

    # Action buttons first, then the alternative mappings. The order of this
    # table determines the order of the returned list, and a Game Boy button
    # already present is never added twice.
    action_map = (
        (18, "A"),       # SOUTH -> A
        (5, "B"),        # EAST -> B
        (19, "START"),   # START
        (0, "SELECT"),   # BACK -> SELECT
        (10, "A"),       # NORTH -> A
        (20, "B"),       # WEST -> B
        (7, "A"),        # LEFT_SHOULDER -> A
        (14, "B"),       # RIGHT_SHOULDER -> B
    )
    for index, name in action_map:
        if is_pressed(index) and name not in pressed_buttons:
            pressed_buttons.append(name)

    return pressed_buttons
135
+
136
def _download_file(url, destination, timeout=60):
    """Stream *url* into *destination* through a sibling ``.part`` file."""
    partial = destination.with_name(destination.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Only a complete download ever appears under the cached name.
        partial.replace(destination)
    finally:
        partial.unlink(missing_ok=True)


def _download_checkpoint(ckpt_path):
    """Fetch ``nvidia/NitroGen/ng.pt`` into *ckpt_path* through a ``.part`` file."""
    partial = ckpt_path.with_name(ckpt_path.name + ".part")
    try:
        HfFileSystem().get_file("nvidia/NitroGen/ng.pt", str(partial))
        if partial.exists():
            partial.replace(ckpt_path)
    finally:
        partial.unlink(missing_ok=True)


def play_pokemon(
    cfg_scale: float,
    context_length: int,
    max_steps: int,
    frame_skip: int,
    button_threshold: float,
    display_every: int,
    update_delay: float
):
    """Generator that yields frames while playing Pokemon Red.

    Downloads (or reuses from the temp directory) the ROM and the model
    checkpoint, starts an inference session and a headless emulator, then
    alternates between predicting buttons from the current screen and
    advancing the emulator.

    Args:
        cfg_scale: Passed to ``InferenceSession.from_ckpt`` as ``cfg_scale``.
        context_length: Passed to ``InferenceSession.from_ckpt`` as
            ``context_length``.
        max_steps: Number of predict/advance iterations to run.
        frame_skip: Emulator frames advanced per iteration (at least 1).
        button_threshold: Activation threshold for predicted buttons.
        display_every: Yield a display update every this many iterations
            (at least 1).
        update_delay: Seconds to sleep after each display update.

    Yields:
        ``(frame, action_markdown, stats_markdown)`` tuples. ``frame`` is
        ``None`` and ``stats_markdown`` is ``None`` for setup status messages;
        during play ``frame`` is the upscaled screen array.
    """
    # Slider values are coerced defensively: the emulator tick count and the
    # modulo below need integers, and zero would divide by zero.
    context_length = int(context_length)
    max_steps = int(max_steps)
    frame_skip = max(1, int(frame_skip))
    display_every = max(1, int(display_every))

    # Download ROM from URL
    yield None, "⏳ Downloading ROM file...", None
    temp_dir = Path(tempfile.gettempdir())
    rom_path = temp_dir / "PokemonRed.gb"
    try:
        # Download ROM if not already cached
        if not rom_path.exists():
            _download_file(ROM_URL, rom_path)
            yield None, "✅ ROM downloaded successfully", None
            time.sleep(0.5)
        else:
            yield None, "✅ Using cached ROM", None
            time.sleep(0.3)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        yield None, f"❌ Error downloading ROM: {e}", None
        return

    # Download checkpoint from HuggingFace using HfFileSystem
    yield None, "⏳ Downloading checkpoint from nvidia/NitroGen...", None
    ckpt_path = temp_dir / "ng.pt"
    try:
        # Download checkpoint from HuggingFace Hub if not already cached
        if not ckpt_path.exists():
            _download_checkpoint(ckpt_path)

            if not ckpt_path.exists():
                yield None, "❌ Failed to download checkpoint from HuggingFace", None
                return

            yield None, "✅ Checkpoint downloaded successfully", None
            time.sleep(0.5)
        else:
            yield None, "✅ Using cached checkpoint", None
            time.sleep(0.3)
    except Exception as e:
        yield None, f"❌ Error downloading checkpoint: {e}", None
        return

    # Initialize inference session
    yield None, "⏳ Initializing inference session...", None
    session = InferenceSession.from_ckpt(
        str(ckpt_path),
        cfg_scale=cfg_scale,
        context_length=context_length
    )
    session.reset()

    # Display settings
    width, height = 640, 576

    # Button timing: press briefly (up to 4 frames), then release and wait.
    # This prevents holding buttons for too long (which would cause repeated
    # movement). E.g., with frame_skip=16: press DOWN for 4 frames, release,
    # wait 12 frames. For small frame_skip values the hold is shortened so a
    # step never advances more than frame_skip frames and, whenever
    # frame_skip >= 2, still ends with at least one released frame.
    button_hold_frames = 4
    hold_frames = max(1, min(button_hold_frames, frame_skip - 1))
    idle_frames = frame_skip - hold_frames

    # Initialize PyBoy
    pyboy = PyBoy(str(rom_path), window="null")
    try:
        pyboy.set_emulation_speed(0)  # Unlimited speed

        # Load save state if it exists
        state_path = Path(STATE_PATH)
        if state_path.exists():
            with open(state_path, "rb") as f:
                pyboy.load_state(f)
            yield None, f"✅ Loaded save state: {STATE_PATH}", None
        else:
            yield None, f"⚠️ Save state not found: {STATE_PATH} (starting fresh)", None
        time.sleep(0.3)

        for step_count in range(max_steps):
            # Get screen and predict
            obs_processed = preprocess_img(pyboy.screen.image)
            pred = session.predict(obs_processed)

            # Convert to Game Boy buttons
            buttons_to_press = gamepad_to_gameboy_buttons(pred, button_threshold)

            # Press buttons
            for btn in buttons_to_press:
                pyboy.send_input(GB_BUTTONS[btn])

            # Hold buttons for a few frames (so the action registers). Render
            # here only when this is the last tick of the step.
            pyboy.tick(hold_frames, render=(idle_frames == 0))

            # Release buttons
            for btn in buttons_to_press:
                pyboy.send_input(GB_BUTTONS_RELEASE[btn])

            # Tick remaining frames to complete the frame_skip cycle; a
            # multi-frame tick with render=True renders its final frame.
            if idle_frames > 0:
                pyboy.tick(idle_frames, render=True)

            # Yield display update at specified frequency
            if step_count % display_every == 0:
                # Get frame (lightweight - no text overlay)
                screen_np = pyboy.screen.ndarray
                if screen_np.shape[2] == 4:
                    screen_np = screen_np[:, :, :3]

                # Simple resize
                frame_display = cv2.resize(
                    screen_np,
                    (width, height),
                    interpolation=cv2.INTER_NEAREST
                )

                # Create action info
                pressed_text = ", ".join(buttons_to_press) if buttons_to_press else "None"
                action_info = (
                    f"**Step {step_count}/{max_steps}**\n\n"
                    f"🎮 **Buttons:** {pressed_text}\n\n"
                    f"⚡ **Speed:** {frame_skip}x frame skip\n\n"
                    f"📊 **Progress:** {step_count/max_steps*100:.1f}%"
                )

                # Create stats info
                stats_info = "**Inference Details**\n\n"
                if len(pred.get("buttons", [])) > 0:
                    button_vals = pred["buttons"][0]
                    active_buttons = [
                        f"{token}: {value:.2f}"
                        for token, value in zip(BUTTON_ACTION_TOKENS, button_vals)
                        if value > button_threshold
                    ]
                    if active_buttons:
                        stats_info += "**Active Predictions:**\n"
                        stats_info += "\n".join(f"- {btn}" for btn in active_buttons[:5])
                    else:
                        stats_info += "No buttons above threshold"

                # Yield frame and info (no encoding overhead)
                yield frame_display, action_info, stats_info

                # Delay to allow Gradio to load images properly
                time.sleep(update_delay)
    finally:
        # Stop emulator, including when the consumer closes the generator
        # early or an error occurs after the emulator was created.
        pyboy.stop()
304
+
305
# Create Gradio interface: settings column on the left, game stream plus the
# action/statistics panels on the right, usage notes underneath.
with gr.Blocks(title="NitroGen Pokemon Red Player") as app:
    gr.Markdown("# 🎮 NitroGen Pokemon Red Player")
    gr.Markdown("Stream Pokemon Red gameplay powered by NitroGen AI model")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🤖 Model Settings")
            gr.Markdown("**Model:** nvidia/NitroGen (ng.pt) - automatically downloaded from HuggingFace Hub")
            gr.Markdown("**ROM:** Automatically downloaded from configured URL")
            gr.Markdown(f"**Save State:** {STATE_PATH}")
            cfg_input = gr.Slider(
                label="CFG Scale",
                minimum=0.0,
                maximum=3.0,
                value=1.0,
                step=0.1,
                info="Classifier-free guidance scale"
            )
            ctx_input = gr.Slider(
                label="Context Length",
                minimum=1,
                maximum=32,
                value=1,
                step=1,
                info="Number of past frames to use"
            )

            gr.Markdown("### ⚙️ Playback Settings")
            max_steps_input = gr.Slider(
                label="Max Steps",
                minimum=100,
                maximum=10000,
                value=1000,
                step=100,
                info="Maximum inference steps"
            )
            frame_skip_input = gr.Slider(
                label="Frame Skip",
                minimum=1,
                maximum=64,
                value=16,
                step=1,
                info="Emulator frames per inference"
            )
            button_threshold_input = gr.Slider(
                label="Button Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                info="Threshold for button activation"
            )
            display_every_input = gr.Slider(
                label="Display Every N Steps",
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                info="Update display frequency (1=every step, higher=faster but less frequent)"
            )
            update_delay_input = gr.Slider(
                label="Update Delay (seconds)",
                minimum=0.1,
                maximum=3.0,
                value=1.0,
                step=0.1,
                info="Wait time after each display update (higher=more time for image to load)"
            )

            start_btn = gr.Button("🚀 Start Playing", variant="primary", size="lg")

        with gr.Column(scale=2):
            image_output = gr.Image(
                label="Game Stream",
                height=600,
                interactive=False
            )

            with gr.Row():
                with gr.Column():
                    action_output = gr.Markdown(
                        label="Actions",
                        value="**Waiting to start...**"
                    )
                with gr.Column():
                    stats_output = gr.Markdown(
                        label="Statistics",
                        value="**No data yet**"
                    )

    gr.Markdown("""
    ### 📝 Instructions
    1. Adjust playback settings as needed
    2. Click "Start Playing" to begin streaming
    3. Game frames update in real-time with actions

    **Automatic Setup:**
    - **Model**: nvidia/NitroGen checkpoint (ng.pt) from HuggingFace Hub
    - **ROM**: Downloaded from configured URL
    - **Save State**: Loaded from `./init.state` if available
    - Model and ROM are cached in temp directory for faster subsequent runs

    **Tips:**
    - **Display Every N Steps**: 1 = update every step, higher = faster but less frequent
    - **Update Delay**: 1s default gives images time to load, reduce for faster updates
    - **Frame Skip**: 16 = game runs 16 frames per inference (faster gameplay)
    """)

    # Connect the button to the play function. The input order must match
    # the positional parameters of play_pokemon, and the three outputs match
    # the tuples it yields.
    start_btn.click(
        fn=play_pokemon,
        inputs=[
            cfg_input,
            ctx_input,
            max_steps_input,
            frame_skip_input,
            button_threshold_input,
            display_every_input,
            update_delay_input
        ],
        outputs=[image_output, action_output, stats_output]
    )
428
+
429
if __name__ == "__main__":
    # Serve on every network interface at port 7860, without a public
    # share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app.launch(**launch_options)
435
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ requests
2
+ pillow
3
+ opencv-python
4
+ pyboy
5
+ huggingface-hub
6
+ gradio
7
+ numpy
8
+ nitrogen @ git+https://github.com/MineDojo/NitroGen.git@main