Transformers.HuggingFace
Module for loading pre-trained models from HuggingFace.
We provide a set of APIs to download and load pre-trained models from the HuggingFace hub. Support is mostly added manually, so only a small set of models is available. The most practical way to check whether a model is available in Transformers is to run the HuggingFaceValidation code in the example folder, which uses PyCall.jl to load the model in both Python and Julia. Open an issue/PR if a model you want is not supported here.
There are three main APIs for loading a model: HuggingFace.load_config, HuggingFace.load_tokenizer, and HuggingFace.load_model. These are the functions underlying the HuggingFace.@hgf_str macro, and calling them directly gives you finer control over the loading process.
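For instance, a manual loading sequence might look like the following (a sketch of roughly what hgf"bert-base-cased" does internally; outputs elided):
julia> using Transformers.HuggingFace
julia> cfg = load_config("bert-base-cased");
julia> textenc = load_tokenizer("bert-base-cased"; config = cfg);
julia> model = load_model("bert-base-cased", :model; config = cfg);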
We can load the config of any model, whether or not the model itself is actually supported by Transformers.jl:
julia> load_config("google/pegasus-xsum")
Transformers.HuggingFace.HGFConfig{:pegasus, JSON3.Object{Vector{UInt8}, Vector{UInt64}}, Dict{Symbol, Any}} with 45 entries:
:use_cache => true
:d_model => 1024
:scale_embedding => true
:add_bias_logits => false
:static_position_embeddings => true
:encoder_attention_heads => 16
:num_hidden_layers => 16
:encoder_layerdrop => 0
:num_beams => 8
:max_position_embeddings => 512
:model_type => "pegasus"
⋮ => ⋮
This gives you every value available in the downloaded configuration file. That might be enough for some models, but others rely on default values hard-coded in their Python code. Sometimes you will want to add or overwrite some of the values. This can be done by calling HGFConfig(old_config; key_to_update = new_value, ...), which is used primarily for customizing model loading. For example, you can load a bert-base-cased model for a sequence classification task. However, if you load the model directly:
julia> bert_model = hgf"bert-base-cased:ForSequenceClassification";
julia> bert_model.cls.layer
Dense(W = (768, 2), b = true)
The model is created for 2 label classes by default. To change that, load the config, update the field holding the number of labels, and create the model with the new config:
julia> bertcfg = load_config("bert-base-cased");
julia> bertcfg.num_labels
2
julia> mycfg = HuggingFace.HGFConfig(bertcfg; num_labels = 3);
julia> mycfg.num_labels
3
julia> _bert_model = load_model("bert-base-cased", :ForSequenceClassification; config = mycfg);
julia> _bert_model.cls.layer
Dense(W = (768, 3), b = true)
All config field names follow the same naming as HuggingFace, so you may need to read their documentation to see what is available. However, not every configuration works in Transformers.jl; it's better to check the source in src/huggingface/implementation. Each supported model needs to overload load_model and provide a Julia implementation to be workable.
For the tokenizer, load_tokenizer is basically the same as calling with @hgf_str. Currently, providing a customized config doesn't change much. The tokenizer might also work for an unsupported model, since some models serialize the whole tokenizer object; but not every model does that, and some use components not covered by our implementation.
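For instance, these two calls load the same text encoder (a sketch; outputs elided):
julia> textenc = load_tokenizer("bert-base-cased");
julia> textenc = hgf"bert-base-cased:tokenizer";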
API Reference
Transformers.HuggingFace.@hgf_str — Macro
hgf"<model-name>:<item>"
Get item from model-name. This ensures the required data are downloaded. item can be "config", "tokenizer", or something model-related like "Model" or "ForMaskedLM". Use get_model_type to see which models/tasks are supported. If item is omitted, return a Tuple of <model-name>:tokenizer and <model-name>:model.
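For illustration, a minimal use (a sketch; outputs elided):
julia> cfg = hgf"bert-base-cased:config";
julia> textenc, model = hgf"bert-base-cased"; # item omitted: returns (tokenizer, model)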
Transformers.HuggingFace.HGFConfig — Type
HGFConfig{model_type} <: AbstractDict{Symbol, Any}
The type for holding the configuration of a huggingface model of type model_type.
HGFConfig(base_cfg::HGFConfig; kwargs...)
Return a new HGFConfig object for the same model_type with fields updated from kwargs.
Example
julia> bertcfg = load_config("bert-base-cased");
julia> bertcfg.num_labels
2
julia> mycfg = HuggingFace.HGFConfig(bertcfg; num_labels = 3);
julia> mycfg.num_labels
3
Extended help
Each HGFConfig has a pre-defined set of type-dependent default field values and some field name aliases. For example, (cfg::HGFConfig{:gpt2}).hidden_size is an alias of (cfg::HGFConfig{:gpt2}).n_embd. Using propertynames, hasproperty, getproperty, or getindex will access the default field values if the key is not present in the loaded configuration. On the other hand, using length, keys, haskey, get, or iterate will not interact with the default values (while the name aliases still work).
julia> fakegpt2cfg = HuggingFace.HGFConfig{:gpt2}((a=3,b=5))
Transformers.HuggingFace.HGFConfig{:gpt2, @NamedTuple{a::Int64, b::Int64}, Nothing} with 2 entries:
:a => 3
:b => 5
julia> myfakegpt2cfg = HuggingFace.HGFConfig(fakegpt2cfg; hidden_size = 7)
Transformers.HuggingFace.HGFConfig{:gpt2, @NamedTuple{a::Int64, b::Int64}, @NamedTuple{n_embd::Int64}} with 3 entries:
:a => 3
:b => 5
:n_embd => 7
julia> myfakegpt2cfg[:hidden_size] == myfakegpt2cfg.hidden_size == myfakegpt2cfg.n_embd
true
julia> myfakegpt2cfg.n_layer
12
julia> get(myfakegpt2cfg, :n_layer, "NOT_FOUND")
"NOT_FOUND"
Transformers.HuggingFace.get_model_type — Function
get_model_type(model_type)
See the list of supported model/task types for a given model. For example, use get_model_type(:gpt2) to see all models/tasks that gpt2 supports. The keys of the returned NamedTuple are all possible tasks, which can be used in load_model or @hgf_str.
Example
julia> HuggingFace.get_model_type(:gpt2)
(model = Transformers.HuggingFace.HGFGPT2Model, lmheadmodel = Transformers.HuggingFace.HGFGPT2LMHeadModel)
Transformers.HuggingFace.get_state_dict — Function
get_state_dict(model)
Get the state_dict of the model.
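For instance, to pull the state_dict back out of a loaded model (a sketch; outputs elided):
julia> model = load_model("bert-base-cased", :ForSequenceClassification);
julia> state_dict = HuggingFace.get_state_dict(model);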
Transformers.HuggingFace.load_config — Method
load_config(model_name; local_files_only = false, cache = true)
Load the configuration file of model_name from the huggingface hub. By default, this function checks whether model_name exists on the huggingface hub, downloads the configuration file (and caches it if cache is set), and then loads and returns the config::HGFConfig. If local_files_only = false, it checks whether the configuration file is up-to-date and updates it if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the file in the cache directly and errors out if not found. For managing the caches, see the HuggingFaceApi.jl package. This function requires the configuration file to have a model_type field; if it has none, use load_config(model_type, HuggingFace.load_config_dict(model_name; local_files_only, cache)) with model_type manually provided.
See also: HGFConfig
Example
julia> load_config("bert-base-cased")
Transformers.HuggingFace.HGFConfig{:bert, JSON3.Object{Vector{UInt8}, Vector{UInt64}}, Nothing} with 19 entries:
:architectures => ["BertForMaskedLM"]
:attention_probs_dropout_prob => 0.1
:gradient_checkpointing => false
:hidden_act => "gelu"
:hidden_dropout_prob => 0.1
:hidden_size => 768
:initializer_range => 0.02
:intermediate_size => 3072
:layer_norm_eps => 1.0e-12
:max_position_embeddings => 512
:model_type => "bert"
:num_attention_heads => 12
:num_hidden_layers => 12
:pad_token_id => 0
:position_embedding_type => "absolute"
:transformers_version => "4.6.0.dev0"
:type_vocab_size => 2
:use_cache => true
:vocab_size => 28996
Transformers.HuggingFace.load_config — Method
load_config(model_type, cfg)
Load cfg as model_type. This is used to manually load a config when model_type is not specified in the config. model_type is a Symbol of the model type, like :bert, :gpt2, :t5, etc.
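For instance, loading a raw config dict and tagging it with a model type manually (a sketch; bert-base-cased already carries a model_type field and only serves as a stand-in here):
julia> cfg_dict = HuggingFace.load_config_dict("bert-base-cased");
julia> cfg = HuggingFace.load_config(:bert, cfg_dict);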
Transformers.HuggingFace.load_hgf_pretrained — Method
load_hgf_pretrained(name)
The underlying function of @hgf_str.
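For instance, the following two lines should be equivalent (a sketch; outputs elided):
julia> textenc, model = HuggingFace.load_hgf_pretrained("bert-base-cased");
julia> textenc, model = hgf"bert-base-cased";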
Transformers.HuggingFace.load_model — Function
load_model([model_type::Symbol,] model_name, task = :model [, state_dict];
           trainmode = false, config = nothing, local_files_only = false, cache = true)
Load the model of model_name for task. This function loads the state_dict of model_name and builds a new model according to config, task, and the state_dict. The local_files_only and cache kwargs are passed directly to both load_state_dict and load_config if those are not provided. This function requires the configuration file to have a model_type field; if it has none, use load_model(model_type, model_name, task; kwargs...) with model_type manually provided. trainmode = false disables all dropouts. The state_dict can be provided directly, which is useful when you want to create a new model from a state_dict you already have in hand. Use get_model_type to see which tasks are available.
See also: get_model_type, load_state_dict, load_config, HGFConfig
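For instance, repeating the sequence classification example from the introduction above (a sketch; outputs elided):
julia> mycfg = HuggingFace.HGFConfig(load_config("bert-base-cased"); num_labels = 3);
julia> model = load_model("bert-base-cased", :ForSequenceClassification; config = mycfg);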
Transformers.HuggingFace.load_model — Method
load_model(::Type{T}, config, state_dict = OrderedDict())
Create a new model of type T according to config and state_dict. Missing parameters are initialized according to config. Set the JULIA_DEBUG=Transformers environment variable to see which parameters are missing.
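For instance, building a gpt2 model directly from its config and state_dict (a sketch; the concrete type HGFGPT2Model is taken from the get_model_type example above):
julia> cfg = load_config("gpt2");
julia> state_dict = HuggingFace.load_state_dict("gpt2");
julia> model = HuggingFace.load_model(HuggingFace.HGFGPT2Model, cfg, state_dict);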
Transformers.HuggingFace.load_state_dict — Method
load_state_dict(model_name; local_files_only = false, force_format = :auto, cache = true)
Load the state_dict of model_name from the huggingface hub. By default, this function checks whether model_name exists on the huggingface hub, downloads the model file (and caches it if cache is set), and then loads and returns the state_dict. If local_files_only = false, it checks whether the model file is up-to-date and updates it if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the files in the cache directly and errors out if not found. For managing the caches, see the HuggingFaceApi.jl package. If force_format is :auto, the format from which the weights are loaded is selected automatically. If force_format is :pickle or :safetensor, the corresponding file format is preferred.
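For instance, downloading the weights once and reusing them to build a model (a sketch; outputs elided):
julia> state_dict = HuggingFace.load_state_dict("bert-base-cased");
julia> model = load_model("bert-base-cased", :model, state_dict);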
Transformers.HuggingFace.load_tokenizer — Function
load_tokenizer(model_name; config = nothing, local_files_only = false, cache = true)
Load the text encoder of model_name from the huggingface hub. By default, this function checks whether model_name exists on the huggingface hub, downloads all files required by the text encoder (and caches them if cache is set), and then loads and returns the text encoder. If local_files_only = false, it checks whether all cached files are up-to-date and updates them if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the files in the cache directly and errors out if not found. For managing the caches, see the HuggingFaceApi.jl package.
Example
julia> load_tokenizer("t5-small")
T5TextEncoder(
├─ TextTokenizer(MatchTokenization(PrecompiledNormalizer(WordReplaceNormalizer(UnigramTokenization(EachSplitTokenization(splitter = isspace), unigram = Unigram(vocab_size = 32100, unk = <unk>)), pattern = r"^(?!▁)(.*)$" => s"▁"), precompiled = PrecompiledNorm(...)), 103 patterns)),
├─ vocab = Vocab{String, SizedArray}(size = 32100, unk = <unk>, unki = 3),
├─ endsym = </s>,
├─ padsym = <pad>,
└─ process = Pipelines:
╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
╰─ target[(token, segment)] := SequenceTemplate{String}(Input[1]:<type=1> </s>:<type=1> (Input[2]:<type=1> </s>:<type=1>)...)(target.token)
╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(nothing))(target.token)
╰─ target[token] := TextEncodeBase.trunc_and_pad(nothing, <pad>, tail, tail)(target.token)
╰─ target[token] := TextEncodeBase.nested2batch(target.token)
╰─ target := (target.token, target.attention_mask)
)
Transformers.HuggingFace.save_config — Method
save_config(model_name, config; path = pwd(), config_name = CONFIG_NAME, force = false)
Save the config at <path>/<model_name>/<config_name>. This errors out if the file already exists and force is not set.
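For instance (a sketch; my-bert is a hypothetical directory name):
julia> HuggingFace.save_config("my-bert", mycfg; force = true);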
Transformers.HuggingFace.save_model — Method
save_model(model_name, model; path = pwd(), weight_name = PYTORCH_WEIGHTS_NAME, force = false)
Save the model state_dict at <path>/<model_name>/<weight_name>. This errors out if the file already exists and force is not set.
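For instance, saving the weights next to the config saved above (a sketch; my-bert is the same hypothetical directory name):
julia> HuggingFace.save_model("my-bert", _bert_model; force = true);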
Transformers.HuggingFace.state_dict_to_namedtuple — Method
state_dict_to_namedtuple(state_dict)
Convert a state_dict into a nested NamedTuple.
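For instance (a sketch; the exact nesting depends on the model's parameter names):
julia> state_dict = HuggingFace.load_state_dict("gpt2");
julia> nt = HuggingFace.state_dict_to_namedtuple(state_dict);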