remosleandre commited on
Commit ·
b46b06b
1
Parent(s): 8936391
[FIX] weight_update
Browse files- model.py +1 -1
- model_hugging_face.py +6 -1
model.py
CHANGED
|
@@ -44,7 +44,7 @@ class Architecture(nn.Module):
|
|
| 44 |
|
| 45 |
def load_model():
|
| 46 |
model = Architecture()
|
| 47 |
-
model.load_state_dict(torch.load('model_weights.pth'))
|
| 48 |
return model
|
| 49 |
|
| 50 |
def inference_model(model, input):
|
|
|
|
| 44 |
|
| 45 |
def load_model():
|
| 46 |
model = Architecture()
|
| 47 |
+
model.load_state_dict(torch.load('./model_weights.pth'))
|
| 48 |
return model
|
| 49 |
|
| 50 |
def inference_model(model, input):
|
model_hugging_face.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch
|
|
@@ -59,7 +60,11 @@ class Architecture(PreTrainedModel):
|
|
| 59 |
|
| 60 |
# Loading the model from saved weights
|
| 61 |
def load_model():
|
|
|
|
|
|
|
| 62 |
config = ArchitectureConfig()
|
| 63 |
model = Architecture(config)
|
| 64 |
-
model.load_state_dict(torch.load('model_weights.pth'))
|
| 65 |
return model
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoConfig, AutoModel
|
| 2 |
from transformers import PreTrainedModel, PretrainedConfig
|
| 3 |
import torch.nn as nn
|
| 4 |
import torch
|
|
|
|
| 60 |
|
| 61 |
# Loading the model from saved weights
|
| 62 |
def load_model():
|
| 63 |
+
AutoConfig.register("architecture", ArchitectureConfig)
|
| 64 |
+
AutoModel.register(ArchitectureConfig, Architecture)
|
| 65 |
config = ArchitectureConfig()
|
| 66 |
model = Architecture(config)
|
| 67 |
+
model.load_state_dict(torch.load('./model_weights.pth'))
|
| 68 |
return model
|
| 69 |
+
|
| 70 |
+
load_model()
|