# xlstm-german-wikipedia / configuration_xlstm.py
# Author: stefan-it
# Last commit dbe6e99 (verified): "xlstm-config: temporarily introduce new hidden_size parameter"
import json
from typing import Any, Dict, Optional
from dacite import Config as DaciteConfig
from dacite import from_dict
from omegaconf import OmegaConf
from transformers.configuration_utils import PretrainedConfig
from xlstm import xLSTMLMModelConfig
# from .config_presets import xlstm_cfg_map
class xLSTMConfig(PretrainedConfig):
    """XLSTM configuration class.

    The specific xLSTM model configuration is kept separate from the rest of
    the Hugging Face configuration because it is heavily nested. It is stored
    as an OmegaConf container in ``self._xlstm_config``. A few top-level values
    (``vocab_size``, ``embedding_dim``, ``context_length``, ``hidden_size``)
    are mirrored as plain attributes.
    """

    model_type = "xlstm"

    def __init__(
        self, vocab_size: int = 32000, config: Optional[Dict[str, Any]] = None, **kwargs
    ):
        """Build the configuration.

        Args:
            vocab_size: Vocabulary size. It is stored as an attribute and also
                written into the nested xLSTM config.
            config: Nested xLSTM configuration mapping. ``None`` is treated as
                an empty mapping.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. Each one is
                also written into the nested xLSTM config as a top-level key.
        """
        super().__init__(**kwargs)
        # Fall back to an empty mapping so default construction (config=None)
        # still produces a dict container that the writes below can update.
        cfg = OmegaConf.create({} if config is None else config)
        cfg["vocab_size"] = vocab_size
        # Extra keyword arguments override/extend the nested config.
        # NOTE(review): generic PretrainedConfig kwargs (e.g. torch_dtype) are
        # copied in as well; dacite's strict mode in `to_xlstm_config` will
        # likely reject them unless xLSTMLMModelConfig declares such fields.
        for key, value in kwargs.items():
            cfg[key] = value
        self._xlstm_config = cfg
        self.vocab_size = vocab_size
        self.embedding_dim = cfg.get("embedding_dim")
        self.context_length = cfg.get("context_length")
        # Mirrors embedding_dim under the `hidden_size` name, presumably for
        # code that expects a `hidden_size` attribute on the config.
        self.hidden_size = cfg.get("embedding_dim")

    def to_xlstm_config(self):
        """Materialise the nested config as an ``xLSTMLMModelConfig`` dataclass.

        Interpolations are resolved first, matching ``to_dict``. dacite's
        ``strict=True`` mode is used, so keys without a matching dataclass
        field are rejected.
        """
        return from_dict(
            data_class=xLSTMLMModelConfig,
            data=OmegaConf.to_container(self._xlstm_config, resolve=True),
            config=DaciteConfig(strict=True),
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the configuration to a dictionary for serialization.

        Starts from ``PretrainedConfig.to_dict()``. It then replaces
        ``_xlstm_config`` with a plain, interpolation-resolved container and
        keeps only the whitelisted keys below, in their original order.
        """
        serialized_keys = {
            "vocab_size",
            "embedding_dim",
            "context_length",
            "torch_dtype",
            "_xlstm_config",
            "transformers_version",
            "architectures",
            "model_type",
            # Keep auto_map (only present if it was set on the instance) so it
            # survives a save/load round trip; `from_dict` restores it.
            "auto_map",
        }
        output = super().to_dict()
        output["_xlstm_config"] = OmegaConf.to_container(
            self._xlstm_config, resolve=True
        )
        return {key: value for key, value in output.items() if key in serialized_keys}

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """
        Creates a configuration instance from a dictionary.

        Args:
            config_dict: Must contain ``_xlstm_config`` and ``vocab_size``. A
                non-empty ``auto_map`` is restored as an attribute. All other
                keys are ignored. The mapping itself is not modified.
            **kwargs: Only ``return_unused_kwargs`` is inspected.

        Returns:
            The configuration, or ``(config, {})`` when
            ``return_unused_kwargs`` is truthy.

        Raises:
            KeyError: If ``_xlstm_config`` or ``vocab_size`` is missing.
        """
        # Shallow copy so the pops below do not mutate the caller's dict.
        config_dict = dict(config_dict)
        xlstm_config = config_dict.pop("_xlstm_config")
        vocab_size = config_dict.pop("vocab_size")
        config = cls(vocab_size=vocab_size, config=xlstm_config)
        auto_map = config_dict.pop("auto_map", None)
        if auto_map:
            config.auto_map = auto_map
        # NOTE(review): any other kwargs are dropped rather than reported back
        # as unused, which differs from PretrainedConfig.from_dict.
        if kwargs.get("return_unused_kwargs"):
            return config, {}
        return config

    def to_json_string(self, *args, **kwargs) -> str:
        """
        Serializes the instance to a JSON string.

        Extra positional/keyword arguments are accepted and ignored. This is
        presumably so base-class callers that pass e.g. ``use_diff`` keep
        working.
        """
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json_string(cls, json_string: str):
        """
        Deserializes the instance from a JSON string via ``from_dict``.
        """
        config_dict = json.loads(json_string)
        return cls.from_dict(config_dict)