Transformers.HuggingFace

Module for loading pre-trained models from HuggingFace.

Info

We provide a set of APIs to download & load a pre-trained model from the HuggingFace hub. Most of this is done manually, so only a small set of models is available. The most practical way to check whether a model is available in Transformers is to run the HuggingFaceValidation code in the example folder, which uses PyCall.jl to load the model in both Python and Julia. Open an issue/PR if a model you want is not supported here.

There are three main APIs for loading a model: HuggingFace.load_config, HuggingFace.load_tokenizer, and HuggingFace.load_model. These are the functions underlying the HuggingFace.@hgf_str macro, and they give you finer control over the loading process.
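For illustration, the macro and function forms are roughly interchangeable. The sketch below (using "bert-base-cased" as a placeholder model name) loads the same pieces explicitly:

julia> using Transformers.HuggingFace

julia> cfg = load_config("bert-base-cased");

julia> tkr = load_tokenizer("bert-base-cased"; config = cfg);

julia> model = load_model("bert-base-cased", :model; config = cfg);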

We can load the config of a specific model, regardless of whether the model itself is actually supported by Transformers.jl.

julia> load_config("google/pegasus-xsum")
Transformers.HuggingFace.HGFConfig{:pegasus, JSON3.Object{Vector{UInt8}, Vector{UInt64}}, Dict{Symbol, Any}} with 45 entries:
  :use_cache                       => true
  :d_model                         => 1024
  :scale_embedding                 => true
  :add_bias_logits                 => false
  :static_position_embeddings      => true
  :encoder_attention_heads         => 16
  :num_hidden_layers               => 16
  :encoder_layerdrop               => 0
  :num_beams                       => 8
  :max_position_embeddings         => 512
  :model_type                      => "pegasus"
  ⋮                                => ⋮

This gives you all the values available in the downloaded configuration file. That might be enough for some models, but other models rely on default values hard-coded in their Python code.

Sometimes you will want to add or overwrite some of these values. This can be done by calling HGFConfig(old_config; key_to_update = new_value, ...). It is primarily used for customizing model loading. For example, you can load a bert-base-cased model for a sequence classification task. However, if you load the model directly:

julia> bert_model = hgf"bert-base-cased:ForSequenceClassification";

julia> bert_model.cls.layer
Dense(W = (768, 2), b = true)

By default, the model is created for 2 label classes. To change this, load the config, update the field that controls the number of labels, and create the model with the new config:

julia> bertcfg = load_config("bert-base-cased");

julia> bertcfg.num_labels
2

julia> mycfg = HuggingFace.HGFConfig(bertcfg; num_labels = 3);

julia> mycfg.num_labels
3

julia> _bert_model = load_model("bert-base-cased", :ForSequenceClassification; config = mycfg);

julia> _bert_model.cls.layer
Dense(W = (768, 3), b = true)

All config field names follow the same naming as HuggingFace, so you might need to read their documentation to see what is available. However, not every configuration option works in Transformers.jl. It's better to check the source under src/huggingface/implementation: every supported model needs to overload load_model and provide a Julia implementation to be usable.

For the tokenizer, load_tokenizer is basically the same as calling @hgf_str. Currently, providing a customized config doesn't change much. The tokenizer might also work for unsupported models, because some of them serialize the whole tokenizer object, but not every model does that, and some use components not covered by our implementation.
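As a rough sketch (assuming the model files are available), the following two calls should produce equivalent text encoders:

julia> tkr1 = load_tokenizer("bert-base-cased");

julia> tkr2 = hgf"bert-base-cased:tokenizer";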

API Reference

Transformers.HuggingFace.@hgf_str - Macro
`hgf"<model-name>:<item>"`

Get item from model-name. This will ensure that the required data are downloaded. item can be "config", "tokenizer", or a model-related value like "Model", "ForMaskedLM", etc. Use get_model_type to see which models/tasks are supported. If item is omitted, return a Tuple of <model-name>:tokenizer and <model-name>:model.
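Example (a minimal sketch, using "gpt2" as the model name):

julia> cfg = hgf"gpt2:config";

julia> tkr, model = hgf"gpt2";  # item omitted: returns the tokenizer and the model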

source
Transformers.HuggingFace.HGFConfig - Type
HGFConfig{model_type}

The type for holding the configuration for huggingface model model_type.

HGFConfig(base_cfg::HGFConfig; kwargs...)

Return a new HGFConfig object for the same model_type with fields updated with kwargs.

Example

julia> bertcfg = load_config("bert-base-cased");

julia> bertcfg.num_labels
2

julia> mycfg = HuggingFace.HGFConfig(bertcfg; num_labels = 3);

julia> mycfg.num_labels
3
source
Transformers.HuggingFace.get_model_type - Function
get_model_type(model_type)

See the list of supported model types for the given model. For example, use get_model_type(:gpt2) to see all models/tasks that gpt2 supports. The keys of the returned NamedTuple are all possible tasks that can be used in load_model or @hgf_str.

Example

julia> HuggingFace.get_model_type(:gpt2)
(model = Transformers.HuggingFace.HGFGPT2Model, lmheadmodel = Transformers.HuggingFace.HGFGPT2LMHeadModel)
source
Transformers.HuggingFace.load_config - Method
load_config(model_name; local_files_only = false, cache = true)

Load the configuration file of model_name from the HuggingFace hub. By default, this function checks whether model_name exists on the HuggingFace hub, downloads the configuration file (and caches it if cache is set), and then loads and returns the config::HGFConfig. If local_files_only = false, it checks whether the configuration file is up-to-date and updates it if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the files in the cache directly and errors out if they are not found. For managing the caches, see the HuggingFaceApi.jl package. This function requires the configuration file to have a model_type field; if it doesn't, use load_config(model_type, HuggingFace.load_config_dict(model_name; local_files_only, cache)) with model_type provided manually.

See also: HGFConfig

Example

julia> load_config("bert-base-cased")
Transformers.HuggingFace.HGFConfig{:bert, JSON3.Object{Vector{UInt8}, Vector{UInt64}}, Nothing} with 19 entries:
  :architectures                => ["BertForMaskedLM"]
  :attention_probs_dropout_prob => 0.1
  :gradient_checkpointing       => false
  :hidden_act                   => "gelu"
  :hidden_dropout_prob          => 0.1
  :hidden_size                  => 768
  :initializer_range            => 0.02
  :intermediate_size            => 3072
  :layer_norm_eps               => 1.0e-12
  :max_position_embeddings      => 512
  :model_type                   => "bert"
  :num_attention_heads          => 12
  :num_hidden_layers            => 12
  :pad_token_id                 => 0
  :position_embedding_type      => "absolute"
  :transformers_version         => "4.6.0.dev0"
  :type_vocab_size              => 2
  :use_cache                    => true
  :vocab_size                   => 28996
source
Transformers.HuggingFace.load_config - Method
load_config(model_type, cfg)

Load cfg as a configuration for model_type. This is used to manually load a config when model_type is not specified in the config file. model_type is a Symbol of the model type, like :bert, :gpt2, :t5, etc.
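For example, a sketch of specifying the model type manually (using HuggingFace.load_config_dict, mentioned under load_config, to get the raw config):

julia> cfg_dict = HuggingFace.load_config_dict("bert-base-cased");

julia> cfg = load_config(:bert, cfg_dict);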

source
Transformers.HuggingFace.load_model - Function
load_model([model_type::Symbol,] model_name, task = :model [, state_dict];
           trainmode = false, config = nothing, local_files_only = false, cache = true)

Load the model of model_name for task. This function loads the state_dict of model_name and builds a new model according to config, task, and the state_dict. The local_files_only and cache kwargs are passed directly to both load_state_dict and load_config if those are not provided. This function requires the configuration file to have a model_type field; if it doesn't, use load_model(model_type, model_name, task; kwargs...) with model_type provided manually. trainmode = false disables all dropouts. The state_dict can also be provided directly; this is useful when you want to create a new model from a state_dict you already have in hand. Use get_model_type to see which tasks are available.

See also: get_model_type, load_state_dict, load_config, HGFConfig
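Example (a sketch; the task symbols come from get_model_type, and mycfg refers to the customized config built in the section above):

julia> gpt2_lm = load_model("gpt2", :lmheadmodel);

julia> bert_clf = load_model("bert-base-cased", :ForSequenceClassification; config = mycfg);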

source
Transformers.HuggingFace.load_model - Method
load_model(::Type{T}, config, state_dict = OrderedDict())

Create a new model of type T according to config and state_dict. Missing parameters are initialized according to config. Set the JULIA_DEBUG=Transformers environment variable to see which parameters are missing.
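A sketch of building a model type directly from a config and state_dict (the concrete type is taken from get_model_type):

julia> T = HuggingFace.get_model_type(:gpt2).model;  # Transformers.HuggingFace.HGFGPT2Model

julia> cfg = load_config("gpt2");

julia> state_dict = load_state_dict("gpt2");

julia> model = load_model(T, cfg, state_dict);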

source
Transformers.HuggingFace.load_state_dict - Method
load_state_dict(model_name; local_files_only = false, cache = true)

Load the state_dict of model_name from the HuggingFace hub. By default, this function checks whether model_name exists on the HuggingFace hub, downloads the model file (and caches it if cache is set), and then loads and returns the state_dict. If local_files_only = false, it checks whether the model file is up-to-date and updates it if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the files in the cache directly and errors out if they are not found. For managing the caches, see the HuggingFaceApi.jl package.
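For example (sketch):

julia> state_dict = load_state_dict("bert-base-cased");  # maps parameter names (as in the PyTorch checkpoint) to arrays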

source
Transformers.HuggingFace.load_tokenizer - Function
load_tokenizer(model_name; config = nothing, local_files_only = false, cache = true)

Load the text encoder of model_name from the HuggingFace hub. By default, this function checks whether model_name exists on the HuggingFace hub, downloads all files required by the text encoder (and caches them if cache is set), and then loads and returns the text encoder. If local_files_only = false, it checks whether all cached files are up-to-date and updates them if not (and thus requires network access every time it is called). By setting local_files_only = true, it tries to find the files in the cache directly and errors out if they are not found. For managing the caches, see the HuggingFaceApi.jl package.

Example

julia> load_tokenizer("t5-small")
T5TextEncoder(
├─ TextTokenizer(MatchTokenization(PrecompiledNormalizer(WordReplaceNormalizer(UnigramTokenization(EachSplitTokenization(splitter = isspace), unigram = Unigram(vocab_size = 32100, unk = <unk>)), pattern = r"^(?!▁)(.*)$" => s"▁"), precompiled = PrecompiledNorm(...)), 103 patterns)),
├─ vocab = Vocab{String, SizedArray}(size = 32100, unk = <unk>, unki = 3),
├─ endsym = </s>,
├─ padsym = <pad>,
└─ process = Pipelines:
  ╰─ target[token] := TextEncodeBase.nestedcall(string_getvalue, source)
  ╰─ target[token] := Transformers.TextEncoders.grouping_sentence(target.token)
  ╰─ target[(token, segment)] := SequenceTemplate{String}(Input[1]:<type=1> </s>:<type=1> (Input[2]:<type=1> </s>:<type=1>)...)(target.token)
  ╰─ target[attention_mask] := (NeuralAttentionlib.LengthMask ∘ Transformers.TextEncoders.getlengths(nothing))(target.token)
  ╰─ target[token] := TextEncodeBase.trunc_and_pad(nothing, <pad>, tail, tail)(target.token)
  ╰─ target[token] := TextEncodeBase.nested2batch(target.token)
  ╰─ target := (target.token, target.attention_mask)
)
source
Transformers.HuggingFace.save_config - Method
save_config(model_name, config; path = pwd(), config_name = CONFIG_NAME, force = false)

Save the config at <path>/<model_name>/<config_name>. This would error out if the file already exists but force is not set.
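For example (a sketch, reusing mycfg from the section above; by default the file is written under the current working directory):

julia> save_config("my-bert", mycfg; force = true);  # writes ./my-bert/<CONFIG_NAME>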

source
Transformers.HuggingFace.save_model - Method
save_model(model_name, model; path = pwd(), weight_name = PYTORCH_WEIGHTS_NAME, force = false)

Save the model state_dict at <path>/<model_name>/<weight_name>. This would error out if the file already exists but force is not set.

source