MLlib + Automated MLflow Tracking (Python)

Automated MLflow tracking in MLlib

MLflow provides automated tracking for model tuning with MLlib. With automated MLflow tracking, when you run tuning code using CrossValidator or TrainValidationSplit, the specified hyperparameters and evaluation metrics are automatically logged, making it easy to identify the optimal model.

This notebook shows an example of automated MLflow tracking with MLlib.

This notebook uses the PySpark classes DecisionTreeClassifier and CrossValidator to train and tune a model. MLflow automatically tracks the learning process, saving the results of each run so you can examine the hyperparameters to understand the impact of each one on the model's performance and find the optimal settings.

This notebook uses the MNIST handwritten digit recognition dataset, which is included with Databricks.

Load the training and test datasets

The dataset is already divided into training and test sets. Each dataset has two columns: an image, represented as a vector of 784 pixels, and a "label", or the actual number shown in the image.

The datasets are stored in the LIBSVM dataset format. Load them using the MLlib LIBSVM dataset reader utility.
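
To make the file format concrete, here is a minimal pure-Python sketch of how one LIBSVM line breaks down into a label and a sparse set of index/value pairs. The helper name `parse_libsvm_line` and the sample line are illustrative assumptions, not part of the notebook or of the MLlib reader. LIBSVM files use 1-based feature indices, and the Spark reader exposes them as 0-based, which the sketch mimics.

```python
def parse_libsvm_line(line):
    """Parse one LIBSVM-formatted line into (label, indices, values).

    Indices in the file are 1-based; they are shifted to 0-based here,
    mirroring what the Spark LIBSVM reader does when it loads the data.
    """
    parts = line.split()
    label = float(parts[0])
    indices, values = [], []
    for item in parts[1:]:
        index, value = item.split(":")
        indices.append(int(index) - 1)
        values.append(float(value))
    return label, indices, values


# Illustrative line with only two non-zero pixels (hypothetical, shortened).
label, indices, values = parse_libsvm_line("5 153:3 154:18")
print(label, indices, values)
```

Only the non-zero pixels are stored, which is why the loaded "features" column holds sparse vectors rather than 784 explicit values per image.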

training = spark.read.format("libsvm") \
  .option("numFeatures", "784") \
  .load("/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt")
test = spark.read.format("libsvm") \
  .option("numFeatures", "784") \
  .load("/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt")
 
training.cache()
test.cache()
 
print("There are {} training images and {} test images.".format(training.count(), test.count()))
There are 60000 training images and 10000 test images.

Display the data. Each image has the true label (the label column) and a vector of features that represent pixel intensities.

display(training)
 
Output (each features value is a sparse vector of size 784, displayed as [0, 784, [indices], [values]] and abbreviated here):

label  features
5      [0, 784, [152, 153, 154, ...], [3, 18, 18, ...]]
0      [0, 784, [127, 128, 129, ...], [51, 159, 253, ...]]
4      [0, 784, [160, 161, 162, ...], [67, 232, 39, ...]]
1      [0, 784, [158, 159, 160, ...], [124, 253, 255, ...]]
9      [0, 784, [208, 209, 210, ...], [55, 148, 210, ...]]
2      [0, 784, [155, 156, 157, ...], [13, 25, 100, ...]]
1      [0, 784, [124, 125, 126, ...], [145, 255, 211, ...]]
3      [0, 784, [151, 152, 153, ...], [38, 43, 105, ...]]
1      [0, 784, [152, 153, 154, ...], [5, 63, 197, ...]]
4      [0, 784, [134, 135, 161, ...], [189, 190, 143, ...]]
3      [0, 784, [123, 124, 125, ...], [42, 118, 219, ...]]

Showing the first 715 rows.

Define the ML pipeline

In this example, as with most ML applications, you must do some preprocessing of the data before you can use the data to train a model. MLlib provides pipelines that allow you to combine multiple steps into a single workflow. In this example, you build a two-step pipeline:

  1. StringIndexer treats the numeric labels as categories and maps each one to a label index, which it writes to the "indexedLabel" column.
  2. DecisionTreeClassifier trains a decision tree that can predict the "label" column based on the data in the "features" column.
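
The indexing step can be sketched in plain Python. This is an illustration only: it assumes StringIndexer's default ordering (most frequent label gets index 0), and the function name `fit_string_index` and the sample labels are hypothetical.

```python
from collections import Counter


def fit_string_index(labels):
    """Map each distinct label to an index, most frequent label first.

    Ties are broken by the label's string form so the result is deterministic.
    """
    counts = Counter(str(label) for label in labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(position) for position, label in enumerate(ordered)}


# Hypothetical labels, not taken from the MNIST data.
labels = [1, 1, 1, 7, 7, 3]
mapping = fit_string_index(labels)
indexed = [mapping[str(label)] for label in labels]
print(mapping)
print(indexed)
```

Because the index depends on label frequency, the indexed value of a digit is generally not the digit itself, which is why the classifier and the evaluator both read "indexedLabel" rather than "label".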

For more information:
Pipelines

from pyspark.ml.classification import DecisionTreeClassifier, DecisionTreeClassificationModel
from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline
# StringIndexer: Convert the input column "label" (digits) to categorical values
indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
# DecisionTreeClassifier: Learn to predict column "indexedLabel" using the "features" column
dtc = DecisionTreeClassifier(labelCol="indexedLabel")
# Chain indexer + dtc together into a single ML Pipeline
pipeline = Pipeline(stages=[indexer, dtc])

Run the cross-validation

Now that you have defined the pipeline, you can run cross-validation to tune the model's hyperparameters. During this process, MLflow automatically tracks the models produced by CrossValidator, along with their evaluation metrics. This allows you to investigate how specific hyperparameters affect the model's performance.

In this example, you examine two hyperparameters in the cross-validation:

  • maxDepth. This parameter determines how deep, and thus how large, the tree can grow.
  • maxBins. For efficient distributed training of decision trees, MLlib discretizes (or "bins") continuous features into a finite number of values. The number of bins is controlled by maxBins. In this example, the number of bins corresponds to the number of grayscale levels; maxBins=2 effectively turns the images into black-and-white images.
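
The effect of binning on pixel intensities can be illustrated with a small sketch. It uses equal-width bins over the range 0 to 255 as a simplifying assumption; MLlib chooses its bin boundaries from the training data, so its actual thresholds will differ. The helper name `to_bin` is hypothetical.

```python
def to_bin(intensity, num_bins, max_intensity=255):
    """Assign an intensity in [0, max_intensity] to one of num_bins equal-width bins."""
    width = (max_intensity + 1) / num_bins
    return min(int(intensity // width), num_bins - 1)


# A few sample grayscale intensities.
pixels = [0, 18, 126, 175, 253]

two_levels = [to_bin(p, 2) for p in pixels]    # black or white
eight_levels = [to_bin(p, 8) for p in pixels]  # eight shades of gray
print(two_levels)
print(eight_levels)
```

With two bins, every pixel collapses to "dark" or "light"; with eight bins, more of the original shading survives, at the cost of more candidate splits for the tree to consider.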

For more information:
maxBins
maxDepth

# Create an evaluator.  In this case, use "weightedPrecision".
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(labelCol="indexedLabel", metricName="weightedPrecision")
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# Define the parameter grid to examine.
grid = ParamGridBuilder() \
  .addGrid(dtc.maxDepth, [2, 3, 4, 5, 6, 7, 8]) \
  .addGrid(dtc.maxBins, [2, 4, 8]) \
  .build()
# Create a cross validator, using the pipeline, evaluator, and parameter grid you created in previous steps.
cv = CrossValidator(estimator=pipeline, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3)
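
As a rough guide to how much work this configuration implies, the following sketch counts the hyperparameter combinations in the grid and the number of pipeline fits that 3-fold cross-validation performs. It uses only the values from the grid above; the variable names are illustrative.

```python
from itertools import product

max_depth_values = [2, 3, 4, 5, 6, 7, 8]
max_bins_values = [2, 4, 8]
num_folds = 3

# Every (maxDepth, maxBins) pair in the grid is evaluated.
combinations = list(product(max_depth_values, max_bins_values))
num_combinations = len(combinations)

# Each combination is fitted once per fold, then the best combination
# is refitted once on the full training set.
fold_fits = num_combinations * num_folds
total_fits = fold_fits + 1

print("combinations:", num_combinations)
print("fits during cross-validation:", fold_fits)
print("fits including the final refit:", total_fits)
```

Adding values to either list multiplies the cost, so it is usually worth starting with a coarse grid and refining it around the best-performing region.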

Run CrossValidator. If an MLflow tracking server is available, CrossValidator automatically logs each run to MLflow, along with the evaluation metric calculated on the held-out data, under the current active run. If no run is active, a new one is created.

# Explicitly create a new run.
# This allows this cell to be run multiple times.
# If you omit mlflow.start_run(), then this cell could run once, but a second run would hit conflicts when attempting to overwrite the first run.
import mlflow
import mlflow.spark
 
with mlflow.start_run():
  # Run the cross validation on the training dataset. The cv.fit() call returns the best model it found.
  cvModel = cv.fit(training)
  
  # Evaluate the best model's performance on the test dataset and log the result.
  test_metric = evaluator.evaluate(cvModel.transform(test))
  mlflow.log_metric('test_' + evaluator.getMetricName(), test_metric) 
  
  # Log the best model.
  mlflow.spark.log_model(spark_model=cvModel.bestModel, artifact_path='best-model') 
MLlib will automatically track trials in MLflow. After your tuning fit() call has completed, view the MLflow UI to see logged runs.

Review the logged results

To view the MLflow experiment associated with the notebook, click the Experiment icon in the notebook context bar on the upper right. All notebook runs appear in the sidebar. To more easily compare their results, click the icon at the far right of Experiment Runs (it shows "View Experiment UI" when you hover over it). The Experiment page appears.

For example, to examine the effect of tuning maxDepth:

  1. On the Experiment page, enter params.maxBins = "8" in the Search Runs box, and click Search.
  2. Select the resulting runs and click Compare.
  3. In the Scatter Plot, select X-axis maxDepth and Y-axis avg_weightedPrecision.

You can see that, when maxBins is held constant at 8, the average weighted precision increases with maxDepth.
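
For reference, weighted precision is the per-class precision averaged with each class weighted by its share of the true labels. The sketch below computes it in plain Python on a small hypothetical set of labels and predictions; the function name `weighted_precision` is illustrative and this is not the MLlib implementation. The avg_ prefix on the logged metric presumably indicates that the value is averaged over the cross-validation folds.

```python
from collections import Counter


def weighted_precision(y_true, y_pred):
    """Per-class precision, weighted by each class's frequency among the true labels."""
    true_counts = Counter(y_true)
    pred_counts = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    total = len(y_true)

    score = 0.0
    for label, n_true in true_counts.items():
        # Precision is taken as 0 for a class that was never predicted.
        precision = correct[label] / pred_counts[label] if pred_counts[label] else 0.0
        score += precision * n_true / total
    return score


# Hypothetical labels and predictions for three classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(weighted_precision(y_true, y_pred))
```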