tune-classification-model-hugging-face-transformers(Python)


Tune a text classification model with Hugging Face Transformers

This notebook uses the 🤗 Transformers library to train an SMS spam classifier on a single-GPU machine, with "distilbert-base-uncased" as the base model.

It first downloads a small dataset, copies it to DBFS, and then converts it to a Spark DataFrame. Preprocessing up to tokenization is done on Spark. DBFS is used only as a convenience for accessing the datasets directly as local files on the driver. You can modify the notebook to avoid using DBFS.

The SMS messages are tokenized with the base model's default tokenizer from transformers, so that tokenization stays consistent with the base model. The notebook uses the Trainer utility in the transformers library to fine-tune the model. It then wraps the tokenizer and trained model in a transformers pipeline and logs the pipeline as an MLflow model. This makes it easy to apply the pipeline directly as a UDF to string columns of a Spark DataFrame.

Cluster setup

For this notebook, Databricks recommends a single-GPU cluster, such as a g4dn.xlarge on AWS or Standard_NC4as_T4_v3 on Azure. You can create a single-machine cluster using the personal compute policy or by choosing "Single Node" when creating a cluster. This notebook works with Databricks Runtime ML GPU version 11.1 or greater. To use Databricks Runtime ML GPU versions 9.1 through 10.4, replace the following command with %pip install --upgrade transformers datasets evaluate.

The transformers library is installed by default on Databricks Runtime ML. This notebook also requires 🤗 Datasets and 🤗 Evaluate, which you can install using %pip.


Set up any parameters for the notebook.

  • The base model, DistilBERT base model (uncased), is a strong foundation model that is smaller and faster than BERT base model (uncased) while providing similar performance. This notebook fine-tunes this base model.
  • The tutorial_path sets the path in DBFS that the notebook uses to write the sample dataset. It is deleted by the last command in this notebook.
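The parameter cell itself is not included in this export. A minimal sketch of what it might contain follows; the model ID comes from the text above, while the value of `tutorial_path` is a hypothetical placeholder.

```python
# Notebook parameters (a sketch; only the model ID is taken from the notebook text).

# Hugging Face Hub ID of the base checkpoint to fine-tune.
base_model = "distilbert-base-uncased"

# Hypothetical DBFS directory for the sample dataset. The cleanup step at the
# end of the notebook deletes this directory.
tutorial_path = "/FileStore/sms_tutorial"
```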

Data download and loading

Start by downloading the dataset and loading it into a Spark DataFrame. The SMS Spam Collection Dataset is available from the UCI Machine Learning Repository.

--2023-01-25 23:28:07-- https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
HTTP request sent, awaiting response... 200 OK
Length: 203415 (199K) [application/x-httpd-php]
2023-01-25 23:28:07 (1.36 MB/s) - ‘smsspamcollection.zip.4’ saved [203415/203415]
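The output above comes from `wget`. If you prefer to stay in Python, a minimal sketch using only the standard library is shown below. The URL is the one shown in the log; the `download` helper name is hypothetical.

```python
import urllib.request
from pathlib import Path

# UCI download URL taken from the wget output above.
SMS_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"


def download(url: str, dest: str) -> Path:
    """Download `url` to the local file `dest` and return its path (hypothetical helper)."""
    dest_path = Path(dest)
    # Make sure the destination directory exists before writing.
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    # urlretrieve writes the response body to dest_path, overwriting any existing file.
    urllib.request.urlretrieve(url, dest_path)
    return dest_path
```

For example, `download(SMS_URL, "smsspamcollection.zip")`. Unlike `wget`, which saved this run's copy as `smsspamcollection.zip.4` because earlier copies existed, `urlretrieve` overwrites the destination file.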

Unzip the downloaded archive.

Archive: smsspamcollection.zip
  inflating: SMSSpamCollection
  inflating: readme
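The same extraction can be done with Python's `zipfile` module. This is a sketch with a hypothetical helper name. Applied to this archive, it should extract the same two members listed in the output above.

```python
import zipfile


def unzip(archive: str, dest_dir: str = ".") -> list:
    """Extract every member of `archive` into `dest_dir` and return the member names."""
    with zipfile.ZipFile(archive) as zf:
        # extractall creates dest_dir (and any parent folders) if needed.
        zf.extractall(dest_dir)
        return zf.namelist()
```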

Copy the file to DBFS.

Out[4]: True
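The `True` output is consistent with `dbutils.fs.cp(...)`, which returns `True` on success, so the notebook probably uses that utility. On clusters where DBFS is mounted on the driver at `/dbfs`, an ordinary file copy has the same effect. The sketch below assumes that FUSE mount. Both the `copy_to_dbfs` helper and the example path are hypothetical.

```python
import shutil
from pathlib import Path


def copy_to_dbfs(local_file: str, dbfs_dir: str, fuse_root: str = "/dbfs") -> Path:
    """Copy a driver-local file into a DBFS directory through the FUSE mount.

    `dbfs_dir` is a DBFS path such as "/FileStore/sms_tutorial" (hypothetical).
    On Databricks its driver-local view is assumed to be `fuse_root` + `dbfs_dir`.
    """
    target_dir = Path(fuse_root) / dbfs_dir.lstrip("/")
    target_dir.mkdir(parents=True, exist_ok=True)
    # shutil.copy returns the path of the newly created file.
    return Path(shutil.copy(local_file, target_dir))
```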

Load the dataset into a DataFrame. The file is tab-separated and has no header, so we specify the separator using sep and specify the column names explicitly.

 
label | text
ham | Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham | Ok lar... Joking wif u oni...
spam | Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham | U dun say so early hor... U c already then say...
ham | Nah I don't think he goes to usf, he lives around here though
spam | FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham | Even my brother is not like to speak with me. They treat me like aids patent.
ham | As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam | WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
spam | Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030
ham | I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k? I've cried enough today.
spam | SIX chances to win CASH! From 100 to 20,000 pounds txt> CSH11 and send to 87575. Cost 150p/day, 6days, 16+ TsandCs apply Reply HL 4 info
spam | URGENT! You have won a 1 week FREE membership in our £100,000 Prize Jackpot! Txt the word: CLAIM to No: 81010 T&C www.dbuk.net LCCLTD POBOX 4403LDNW1A7RW18
ham | I've been searching for the right words to thank you for this breather. I promise i wont take your help for granted and will fulfil my promise. You have been wonderful and a blessing at all times.
ham | I HAVE A DATE ON SUNDAY WITH WILL!!
spam | XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap. xxxmobilemovieclub.com?n=QJKGIGHJJGCBL
ham | Oh k...i'm watching here:)
(1,000 rows displayed; data truncated)
Out[5]: 5574
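In Spark, the read is probably something like `spark.read.csv(path, sep="\t", schema="label STRING, text STRING")`. The exact call is not shown in this export. To inspect the raw file on the driver without Spark, the following standard-library sketch parses the same headerless, tab-separated format. The helper name is hypothetical, and the UTF-8 encoding is an assumption.

```python
import csv


def read_sms_tsv(path: str, encoding: str = "utf-8") -> list:
    """Parse the headerless, tab-separated SMSSpamCollection file into (label, text) tuples.

    The UTF-8 default encoding is an assumption. QUOTE_NONE keeps the quotation
    marks that start some messages as literal text.
    """
    rows = []
    with open(path, newline="", encoding=encoding) as f:
        for record in csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE):
            if not record:
                continue  # skip blank lines
            # The first field is the label; rejoin the rest in case a message contains a tab.
            rows.append((record[0], "\t".join(record[1:])))
    return rows
```

With this parser, `len(read_sms_tsv(...))` on the full file should match the DataFrame row count above, assuming the file has no blank lines.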

Data preparation

The datasets passed to the transformers Trainer for text classification must have integer labels. For this binary task, the labels are 0 and 1.

Collect the labels and generate mappings from labels to IDs and from IDs back to labels. transformers models need these mappings to translate the integer predictions into human-readable labels.
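A minimal sketch of building the two mappings follows. It assumes the distinct labels have already been collected to the driver, for example with `[r.label for r in df.select("label").distinct().collect()]`, where `df` is a hypothetical name for the DataFrame. Sorting the labels makes the assignment deterministic and reproduces the `ham` = 0, `spam` = 1 mapping shown in the output below.

```python
def build_label_maps(labels):
    """Return (label2id, id2label) for the distinct string labels (hypothetical helper).

    Sorting makes the ID assignment deterministic across runs.
    """
    classes = sorted(set(labels))
    label2id = {label: i for i, label in enumerate(classes)}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label
```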

Replace the string labels with the IDs in the DataFrame.

 
label | text
0 | Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
0 | Ok lar... Joking wif u oni...
1 | Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
0 | U dun say so early hor... U c already then say...
0 | Nah I don't think he goes to usf, he lives around here though
1 | FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
0 | Even my brother is not like to speak with me. They treat me like aids patent.
0 | As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
1 | WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.
1 | Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030
0 | I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k? I've cried enough today.
1 | SIX chances to win CASH! From 100 to 20,000 pounds txt> CSH11 and send to 87575. Cost 150p/day, 6days, 16+ TsandCs apply Reply HL 4 info
1 | URGENT! You have won a 1 week FREE membership in our £100,000 Prize Jackpot! Txt the word: CLAIM to No: 81010 T&C www.dbuk.net LCCLTD POBOX 4403LDNW1A7RW18
0 | I've been searching for the right words to thank you for this breather. I promise i wont take your help for granted and will fulfil my promise. You have been wonderful and a blessing at all times.
0 | I HAVE A DATE ON SUNDAY WITH WILL!!
1 | XXXMobileMovieClub: To use your credit, click the WAP link in the next txt message or click here>> http://wap. xxxmobilemovieclub.com?n=QJKGIGHJJGCBL
0 | Oh k...i'm watching here:)
(1,000 rows displayed; data truncated)
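The Spark code for this replacement is not shown in this export. In Spark, it could be a column expression, for example one built with `pyspark.sql.functions.when`. The plain-Python sketch below shows the same transformation on `(label, text)` tuples. The helper name is hypothetical.

```python
def encode_labels(rows, label2id):
    """Replace each row's string label with its integer ID (hypothetical helper).

    `rows` are (label, text) tuples. An unknown label raises KeyError, so bad
    data fails loudly instead of silently receiving a wrong class.
    """
    return [(label2id[label], text) for label, text in rows]
```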

Hugging Face datasets supports loading from Spark DataFrames using datasets.Dataset.from_spark. See the Hugging Face documentation to learn more about the from_spark() method.

Dataset.from_spark caches the dataset. In this example, the model is trained on the driver, but the caching step is parallelized with Spark, so cache_dir must be accessible to the driver and to all the workers. You can use the Databricks File System (DBFS) root (AWS | Azure | GCP) or a mount point (AWS | Azure | GCP).

By using DBFS, you can reference "local" paths when creating the transformers compatible datasets used for model training.
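The "local" view of DBFS on the driver is assumed to be the `/dbfs` FUSE mount, so a DBFS location such as `dbfs:/FileStore/...` is read as `/dbfs/FileStore/...`. The small hypothetical helper below performs that translation. Its result could then be passed as, for example, `cache_dir` to `Dataset.from_spark(df, cache_dir=...)`.

```python
def dbfs_local_path(path: str, fuse_root: str = "/dbfs") -> str:
    """Map a DBFS path to its driver-local FUSE path (hypothetical helper).

    Accepts "dbfs:/FileStore/x" or "/FileStore/x" and returns "/dbfs/FileStore/x".
    Paths already under the FUSE root are returned unchanged.
    """
    if path.startswith("dbfs:"):
        path = path[len("dbfs:"):]
    if path == fuse_root or path.startswith(fuse_root + "/"):
        return path
    return fuse_root + "/" + path.lstrip("/")
```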

Tokenize and shuffle the datasets for training. Because the Trainer does not need the untokenized text column for training, the notebook removes it from the dataset. This isn't strictly necessary, but keeping the column results in a warning during training. In this step, datasets also caches the transformed datasets on local disk, so that model training can load them quickly.
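A minimal sketch of a batched tokenization function for `Dataset.map(batched=True)` follows. It assumes a transformers tokenizer, which returns `input_ids` and `attention_mask` for a list of strings. Padding is left to the data collator, so this step only truncates. The function name, `train_ds`, and the seed value are hypothetical.

```python
def tokenize_batch(examples, tokenizer, max_length=None):
    """Tokenize a batch of messages for Dataset.map(batched=True) (hypothetical helper).

    `examples["text"]` is a list of strings. With truncation=True and
    max_length=None, the tokenizer truncates to the model's maximum input length.
    Padding is deferred to the data collator.
    """
    return tokenizer(examples["text"], truncation=True, max_length=max_length)


# Usage (requires transformers and datasets; `train_ds` is hypothetical):
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# train_ds = train_ds.map(lambda b: tokenize_batch(b, tokenizer), batched=True,
#                         remove_columns=["text"]).shuffle(seed=42)
```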


Model training

For model training, this notebook largely uses default behavior. However, you can use the full range of metrics and parameters available to the Trainer to adjust your model training behavior.

Create the evaluation metric to log. Loss is also logged, but adding other metrics such as accuracy makes model performance easier to interpret.
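The notebook likely computes accuracy with 🤗 Evaluate, for example `evaluate.load("accuracy").compute(predictions=preds, references=labels)`. The sketch below swaps in a dependency-free accuracy computation with the same shape as a Trainer `compute_metrics` hook. It unpacks `(logits, labels)` and returns a dict of metric values.

```python
def compute_metrics(eval_pred):
    """Accuracy for the Trainer's compute_metrics hook (pure-Python sketch).

    `eval_pred` unpacks into (logits, labels). `logits` holds one row of class
    scores per example. This works with NumPy arrays or nested lists.
    """
    logits, labels = eval_pred
    # Predicted class = index of the highest score in each row (argmax).
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    correct = sum(int(p == int(y)) for p, y in zip(preds, labels))
    return {"accuracy": correct / len(preds)}
```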

Construct the default training arguments. This is where you would set most of your training parameters, such as the learning rate. Refer to the transformers documentation for the full range of arguments you can set.
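With default TrainingArguments (8 examples per device and 3 epochs), the Trainer's step count can be checked by hand. The training log later in this notebook reports 4426 training examples and 1662 total optimization steps. The training set yields ceil(4426 / 8) = 554 batches per epoch, and 554 × 3 = 1662. The sketch below reproduces that arithmetic for a run without `max_steps`. The helper name is hypothetical.

```python
import math


def total_optimization_steps(num_examples, per_device_batch_size=8, num_epochs=3,
                             grad_accum_steps=1, num_devices=1):
    """Estimate how many optimizer updates the Trainer performs (hypothetical helper).

    The defaults mirror TrainingArguments defaults on a single GPU. The final
    partial batch of each epoch still counts as a batch.
    """
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * num_epochs
```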

Create the model to train from the base model, specifying the label mappings and the number of classes.
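A sketch of this step follows. `AutoModelForSequenceClassification.from_pretrained` accepts `num_labels`, `id2label`, and `label2id`, which are stored in the model config. The `id2label` entry of the saved config appears in the training log later in this notebook. The consistency check and both helper names are illustrative additions.

```python
def classification_kwargs(id2label, label2id):
    """Build from_pretrained keyword arguments for a classifier (hypothetical helper).

    Verifies that the two mappings are inverses, so predictions are labeled consistently.
    """
    inverse = len(id2label) == len(label2id) and all(
        label2id.get(label) == i for i, label in id2label.items()
    )
    if not inverse:
        raise ValueError("id2label and label2id must be inverse mappings")
    return {"num_labels": len(id2label), "id2label": id2label, "label2id": label2id}


def build_model(base_model, id2label, label2id):
    # Imported lazily, so classification_kwargs works without transformers installed.
    from transformers import AutoModelForSequenceClassification
    return AutoModelForSequenceClassification.from_pretrained(
        base_model, **classification_kwargs(id2label, label2id)
    )
```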

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias'] - This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'classifier.bias', 'pre_classifier.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

A data collator assembles individual examples from the training and evaluation datasets into batches. DataCollatorWithPadding with its default settings pads each batch to the length of its longest sequence, which gives good baseline performance for text classification.
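To show what dynamic padding does, here is a simplified pure-Python illustration. It is not the transformers implementation, which also returns PyTorch tensors and handles more keys. It right-pads each example to the longest sequence in the batch, builds the attention mask, and renames `label` to `labels` as the Trainer expects. `pad_token_id=0` matches the DistilBERT config printed in the training log.

```python
def pad_batch(features, pad_token_id=0):
    """Pad tokenized examples to the longest sequence in the batch (illustration only)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(list(f["input_ids"]) + [pad_token_id] * n_pad)
        # 1 marks real tokens; 0 marks padding that attention should ignore.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["label"])
    return batch
```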

Construct the trainer object with the model, arguments, datasets, collator, and metrics created above.

Construct the MLflow wrapper class that stores the model as a pipeline. When the pipeline is loaded, the model uses the GPU if CUDA is available. The wrapper hard-codes the batch size passed to the transformers pipeline, so set it with your inference hardware in mind.
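The wrapper class itself is not included in this export. In the notebook, it is presumably an `mlflow.pyfunc.PythonModel` subclass. Its `load_context` would build the pipeline, and its `predict` would call it. The sketch below factors that logic into plain functions, using hypothetical names. transformers pipelines take `device=0` for the first GPU and `device=-1` for the CPU, and the returned label names come from `id2label` in the model config. `./sms_pipeline` is the directory shown in the training log, and `batch_size=8` is a placeholder.

```python
def pick_device(cuda_available: bool) -> int:
    """Return the pipeline device index: 0 for the first GPU, -1 for CPU."""
    return 0 if cuda_available else -1


def predict_labels(pipe, texts, batch_size=8):
    """Run a text-classification pipeline and keep only the predicted label names.

    The pipeline returns one {"label": ..., "score": ...} dict per input.
    batch_size=8 is a placeholder; tune it for your inference hardware.
    """
    results = pipe(list(texts), batch_size=batch_size)
    return [r["label"] for r in results]


def load_pipeline(path="./sms_pipeline"):
    # Lazy imports: only needed when loading the real model and tokenizer.
    import torch
    from transformers import pipeline
    return pipeline("text-classification", model=path, tokenizer=path,
                    device=pick_device(torch.cuda.is_available()))
```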

Train the model, logging metrics and results to MLflow. This task is very easy for BERT-based models, so don't be surprised if the evaluation accuracy is 1 or close to 1.

/databricks/python/lib/python3.9/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
***** Running training *****
  Num examples = 4426
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 1662
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Saving model checkpoint to sms_trainer/checkpoint-500
Configuration saved in sms_trainer/checkpoint-500/config.json
Model weights saved in sms_trainer/checkpoint-500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 1148
  Batch size = 8
Saving model checkpoint to sms_trainer/checkpoint-1000
Configuration saved in sms_trainer/checkpoint-1000/config.json
Model weights saved in sms_trainer/checkpoint-1000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 1148
  Batch size = 8
Saving model checkpoint to sms_trainer/checkpoint-1500
Configuration saved in sms_trainer/checkpoint-1500/config.json
Model weights saved in sms_trainer/checkpoint-1500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 1148
  Batch size = 8
Training completed. Do not forget to share your model on huggingface.co/models =)
Saving model checkpoint to ./sms_model
Configuration saved in ./sms_model/config.json
Model weights saved in ./sms_model/pytorch_model.bin
loading configuration file ./sms_model/config.json
Model config DistilBertConfig { "_name_or_path": "./sms_model", "activation": "gelu", "architectures": [ "DistilBertForSequenceClassification" ], "attention_dropout": 0.1, "dim": 768, "dropout": 0.1, "hidden_dim": 3072, "id2label": { "0": "ham", "1": "spam" }, "initializer_range": 0.02, "label2id": { "ham": 0, "spam": 1 }, "max_position_embeddings": 512, "model_type": "distilbert", "n_heads": 12, "n_layers": 6, "pad_token_id": 0, "problem_type": "single_label_classification", "qa_dropout": 0.1, "seq_classif_dropout": 0.2, "sinusoidal_pos_embds": false, "tie_weights_": true, "torch_dtype": "float32", "transformers_version": "4.23.1", "vocab_size": 30522 }
loading weights file ./sms_model/pytorch_model.bin
All model checkpoint weights were used when initializing DistilBertForSequenceClassification.
All the weights of DistilBertForSequenceClassification were initialized from the model checkpoint at ./sms_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use DistilBertForSequenceClassification for predictions without further training.
Configuration saved in ./sms_pipeline/config.json
Model weights saved in ./sms_pipeline/pytorch_model.bin
tokenizer config file saved in ./sms_pipeline/tokenizer_config.json
Special tokens file saved in ./sms_pipeline/special_tokens_map.json
2023/01/25 23:31:10 WARNING mlflow.utils.requirements_utils: Found torch version (1.12.1+cu113) contains a local version label (+cu113). MLflow logged a pip requirement for this package as 'torch==1.12.1' without the local version label to make it installable from PyPI. To specify pip requirements containing local version labels, please use `conda_env` or `pip_requirements`.
/databricks/python/lib/python3.9/site-packages/_distutils_hack/__init__.py:30: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")

Batch inference

Load the model as a UDF using MLflow and use it for batch scoring.

2023/01/25 23:31:16 WARNING mlflow.pyfunc: Calling `spark_udf()` with `env_manager="local"` does not recreate the same environment that was used during training, which may lead to errors or inaccurate predictions. We recommend specifying `env_manager="conda"`, which automatically recreates the environment that was used to train the model and performs inference in the recreated environment.
2023/01/25 23:31:16 INFO mlflow.models.flavor_backend_registry: Selected backend for flavor 'python_function'
 
text | label | prediction
"Happy valentines day" I know its early but i have hundreds of handsomes and beauties to wish. So i thought to finish off aunties and uncles 1st... | 0 | ham
"Response" is one of d powerful weapon 2 occupy a place in others 'HEART'... So, always give response 2 who cares 4 U"... Gud night..swt dreams..take care | 0 | ham
"Wen u miss someone, the person is definitely special for u..... But if the person is so special, why to miss them, just Keep-in-touch" gdeve.. | 0 | ham
&lt;#&gt;  w jetton ave if you forgot | 0 | ham
&lt;#&gt; ISH MINUTES WAS 5 MINUTES AGO. WTF. | 0 | ham
'Wnevr i wana fal in luv vth my books, My bed fals in luv vth me..!'' . Yen madodu, nav pretsorginta, nammanna pretsovru important alwa....!!:) Gud eveB-). | 0 | ham
(No promises on when though, haven't even gotten dinner yet) | 0 | ham
* Am on my way | 0 | ham
, how's things? Just a quick question. | 0 | ham
... Are you in the pub? | 0 | ham
....photoshop makes my computer shut down. | 0 | ham
1's finish meeting call me. | 0 | ham
1's reach home call me. | 0 | ham
3 pa but not selected. | 0 | ham
7 wonders in My WORLD 7th You 6th Ur style 5th Ur smile 4th Ur Personality 3rd Ur Nature 2nd Ur SMS and 1st "Ur Lovely Friendship"... good morning dear | 0 | ham
A Boy loved a gal. He propsd bt she didnt mind. He gv lv lttrs, Bt her frnds threw thm. Again d boy decided 2 aproach d gal , dt time a truck was speeding towards d gal. Wn it was about 2 hit d girl,d boy ran like hell n saved her. She asked 'hw cn u run so fast?' D boy replied "Boost is d secret of my energy" n instantly d girl shouted "our energy" n Thy lived happily 2gthr drinking boost evrydy Moral of d story:- I hv free msgs:D;): gud ni8 | 0 | ham
(1,000 rows displayed; data truncated)

Cleanup

Remove the files placed in DBFS.
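The `True` output is consistent with `dbutils.fs.rm(tutorial_path, True)`, which deletes a directory recursively and returns `True`. This suggests the notebook probably uses that call. The equivalent through the assumed `/dbfs` FUSE mount is a recursive delete with `shutil.rmtree`, as sketched below with a hypothetical helper name.

```python
import shutil
from pathlib import Path


def remove_dbfs_dir(dbfs_dir: str, fuse_root: str = "/dbfs") -> bool:
    """Recursively delete a DBFS directory through the FUSE mount (hypothetical helper).

    Returns True if the directory existed and was deleted, False if it was absent.
    """
    target = Path(fuse_root) / dbfs_dir.lstrip("/")
    if not target.exists():
        return False
    shutil.rmtree(target)
    return True
```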

Out[19]: True