fit_transform() and transform() are methods used in machine learning, particularly when working with data preprocessing steps like scaling or encoding.
-
fit_transform():
- Purpose: This method combines two steps: fitting the preprocessing transformation (learning the parameters from the data) and applying it (transforming the data based on those parameters).
- When to use: You use
fit_transform()on your training data. It learns and applies the transformation to your data in one go. - Example: If you're scaling your data,
fit_transform()will calculate the mean and standard deviation from your training data and then scale the data accordingly.
-
transform():
- Purpose: This method applies a transformation that was previously learned using
fit_transform()orfit(). - When to use: You use
transform()on your test data (or any new data). It applies the transformation using parameters learned from the training data. - Example: After using
fit_transform()on your training data to scale it, you usetransform()on your test data to scale it using the same mean and standard deviation calculated from the training data.
- Purpose: This method applies a transformation that was previously learned using
In simple terms:
- Fit: Learn something from the data.
- Transform: Apply what you've learned.
Key takeaway: fit_transform() does both learning from the data and applying the transformation, whereas transform() only applies a transformation that was already learned. Use fit_transform() on training data and transform() on test or new data to ensure consistency in how your data is processed.
Let's illustrate fit_transform() and transform() with a practical example using data scaling, which is a common preprocessing step in machine learning.
Suppose we have a dataset with numerical features that we want to scale using standardization (subtracting the mean and dividing by the standard deviation).
import numpy as np
from sklearn.preprocessing import StandardScaler
# Dummy dataset (imagine these are your features)
data = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# Step 1: Initialize the scaler
scaler = StandardScaler()
# Step 2: Use fit_transform() on training data
scaled_data = scaler.fit_transform(data)
print("Scaled Data (after fit_transform):")
print(scaled_data)In this example:
-
Initialization (
StandardScaler()): We create aStandardScalerobject. This object will later be used to scale our data. -
fit_transform():
- We apply
fit_transform(data)on our datasetdata. fit_transform()learns the mean and standard deviation fromdataand then transformsdatabased on these statistics.- After
fit_transform(),scalernow holds the mean and standard deviation learned fromdata, andscaled_datacontains the scaled values.
- We apply
-
Output:
scaled_datanow contains the scaled version ofdata, where each column (feature) has a mean of 0 and a standard deviation of 1.
Now, let's simulate a scenario where we have new data (test data) that we want to scale using the same scaling parameters learned from our training data.
# New test data (simulating new data)
new_data = np.array([[10, 11, 12]])
# Step 3: Use transform() on test data
scaled_new_data = scaler.transform(new_data)
print("\nNew Data Scaled (using transform):")
print(scaled_new_data)In this part:
-
transform():
- We use
transform(new_data)on our new datanew_data. transform()applies the scaling using the mean and standard deviation learned from the training data (data).scaled_new_datanow contains the scaled version ofnew_data, using the same scaling parameters (mean and standard deviation) asscaled_data.
- We use
-
Output:
scaled_new_datashows hownew_datawould look like after scaling using the same scaling parameters as the training data.
- fit_transform(): Learns the parameters (like mean and standard deviation) from the data and then transforms the data based on those parameters. Typically used on training data.
- transform(): Applies the transformation using parameters learned from
fit_transform(). Used on test data (or any new data) to ensure consistency with the scaling applied to the training data.
These methods (fit_transform() and transform()) are crucial for maintaining consistency in preprocessing steps across different subsets of data in machine learning workflows.