Skip to content

Commit ed8e3cb

Browse files
committed
feat: configuration as catalog
1 parent 77100c0 commit ed8e3cb

67 files changed

Lines changed: 1764 additions & 2231 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.
Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,21 @@
1+
using Flowthru.Core.Data;
2+
using KedroSpaceflights.Custom.Flows.DataEvaluation.Steps;
3+
using KedroSpaceflights.Custom.Flows.DataScience.Steps;
4+
5+
namespace KedroSpaceflights.Custom.Data;
6+
7+
/// <summary>
8+
/// Configuration catalog for the KedroSpaceflights.Custom pipeline.
9+
/// Properties are bound from appsettings.json via the source-generated constructor.
10+
/// </summary>
11+
[FlowthruConfig]
12+
public partial class FlowConfig
13+
{
14+
/// <summary>Configuration options for the train/test split step.</summary>
15+
[ConfigSection("Flowthru:Flows:DataScience")]
16+
public IItem<CreateTestTrainSplitStep.TestTrainSplitParams> ModelParams { get; }
17+
18+
/// <summary>Configuration options for cross-validation.</summary>
19+
[ConfigSection("Flowthru:Flows:DataEvaluation")]
20+
public IItem<CrossValidateModelStep.Params> CrossValidationParams { get; }
21+
}

examples/advanced/KedroSpaceflights.Custom/Flows/DataEvaluation/DataEvaluationFlow.cs

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -32,18 +32,7 @@ namespace KedroSpaceflights.Custom.Flows.DataEvaluation;
3232
/// </summary>
3333
public static class DataEvaluationFlow
3434
{
35-
/// <summary>
36-
/// Parameters for the data evaluation pipeline.
37-
/// </summary>
38-
public class Params
39-
{
40-
/// <summary>
41-
/// Options for cross-validation.
42-
/// </summary>
43-
public CrossValidateModelStep.Params CrossValidationParams { get; init; } = new();
44-
}
45-
46-
public static Flow Create(Catalog catalog, Params parameters)
35+
public static Flow Create(Catalog catalog, FlowConfig config)
4736
{
4837
return FlowBuilder.CreateFlow(pipeline =>
4938
{
@@ -58,8 +47,8 @@ public static Flow Create(Catalog catalog, Params parameters)
5847
// Step 2: Cross-validation for R² distribution analysis and comparison to Kedro
5948
pipeline.AddStep(
6049
label: "PerformCrossValidatedOLSRegressionTest",
61-
transform: CrossValidateModelStep.Create(parameters.CrossValidationParams),
62-
input: catalog.ModelInputTable,
50+
transform: CrossValidateModelStep.Create,
51+
input: (catalog.ModelInputTable, config.CrossValidationParams),
6352
output: catalog.CrossValidationResults
6453
);
6554
});

examples/advanced/KedroSpaceflights.Custom/Flows/DataEvaluation/Steps/CrossValidateModelStep.cs

Lines changed: 97 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
using Flowthru.Core.Steps;
33
using KedroSpaceflights.Custom.Data._03_Primary.Schemas;
44
using KedroSpaceflights.Custom.Data._06_Reporting.Schemas;
5-
using Microsoft.Extensions.Logging;
65
using Microsoft.ML;
76

87
namespace KedroSpaceflights.Custom.Flows.DataEvaluation.Steps;
@@ -43,138 +42,113 @@ public record Params
4342
public float KedroReferenceR2Score { get; init; }
4443
}
4544

46-
public static Func<IEnumerable<ModelInputSchema>, Task<CrossValidationResults>> Create(
47-
Params? parameters = null,
48-
ILogger? logger = null
45+
public static async Task<CrossValidationResults> Create(
46+
(IEnumerable<ModelInputSchema> Data, Params Options) input
4947
)
5048
{
51-
var config = parameters ?? new Params();
49+
var (rawInput, config) = input;
50+
var data = rawInput.ToList();
5251

53-
return async (input) =>
54-
{
55-
var data = input.ToList();
56-
logger?.LogInformation("Starting cross-validation with {Folds} folds", config.NumFolds);
57-
58-
// Convert to feature rows
59-
var featureRows = data.Select(row => new FeatureRow
60-
{
61-
Engines = (float)row.Engines,
62-
PassengerCapacity = (float)row.PassengerCapacity,
63-
Crew = (float)row.Crew,
64-
DCheckComplete = row.DCheckComplete,
65-
IataApproved = row.IataApproved,
66-
CompanyRating = (float)row.CompanyRating,
67-
ReviewScoresRating = (float)row.ReviewScoresRating,
68-
Price = (float)row.Price,
69-
})
70-
.ToList();
71-
72-
var mlContext = new MLContext(seed: config.BaseSeed);
73-
var allData = mlContext.Data.LoadFromEnumerable(featureRows);
74-
75-
// Define the ML pipeline (same as TrainModelStep)
76-
var pipeline = mlContext
77-
.Transforms.CopyColumns(
78-
outputColumnName: "Label",
79-
inputColumnName: nameof(FeatureRow.Price)
52+
// Convert to feature rows
53+
var featureRows = data.Select(row => new FeatureRow
54+
{
55+
Engines = (float)row.Engines,
56+
PassengerCapacity = (float)row.PassengerCapacity,
57+
Crew = (float)row.Crew,
58+
DCheckComplete = row.DCheckComplete,
59+
IataApproved = row.IataApproved,
60+
CompanyRating = (float)row.CompanyRating,
61+
ReviewScoresRating = (float)row.ReviewScoresRating,
62+
Price = (float)row.Price,
63+
})
64+
.ToList();
65+
66+
var mlContext = new MLContext(seed: config.BaseSeed);
67+
var allData = mlContext.Data.LoadFromEnumerable(featureRows);
68+
69+
// Define the ML pipeline (same as TrainModelStep)
70+
var pipeline = mlContext
71+
.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: nameof(FeatureRow.Price))
72+
.Append(
73+
mlContext.Transforms.Categorical.OneHotEncoding(
74+
outputColumnName: "DCheckCompleteEncoded",
75+
inputColumnName: nameof(FeatureRow.DCheckComplete)
8076
)
81-
.Append(
82-
mlContext.Transforms.Categorical.OneHotEncoding(
83-
outputColumnName: "DCheckCompleteEncoded",
84-
inputColumnName: nameof(FeatureRow.DCheckComplete)
85-
)
77+
)
78+
.Append(
79+
mlContext.Transforms.Categorical.OneHotEncoding(
80+
outputColumnName: "IataApprovedEncoded",
81+
inputColumnName: nameof(FeatureRow.IataApproved)
8682
)
87-
.Append(
88-
mlContext.Transforms.Categorical.OneHotEncoding(
89-
outputColumnName: "IataApprovedEncoded",
90-
inputColumnName: nameof(FeatureRow.IataApproved)
91-
)
83+
)
84+
.Append(
85+
mlContext.Transforms.Concatenate(
86+
"Features",
87+
nameof(FeatureRow.Engines),
88+
nameof(FeatureRow.PassengerCapacity),
89+
nameof(FeatureRow.Crew),
90+
"DCheckCompleteEncoded",
91+
"IataApprovedEncoded",
92+
nameof(FeatureRow.CompanyRating),
93+
nameof(FeatureRow.ReviewScoresRating)
9294
)
93-
.Append(
94-
mlContext.Transforms.Concatenate(
95-
"Features",
96-
nameof(FeatureRow.Engines),
97-
nameof(FeatureRow.PassengerCapacity),
98-
nameof(FeatureRow.Crew),
99-
"DCheckCompleteEncoded",
100-
"IataApprovedEncoded",
101-
nameof(FeatureRow.CompanyRating),
102-
nameof(FeatureRow.ReviewScoresRating)
103-
)
95+
)
96+
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
97+
.Append(
98+
mlContext.Regression.Trainers.OnlineGradientDescent(
99+
labelColumnName: "Label",
100+
featureColumnName: "Features",
101+
numberOfIterations: 1000
104102
)
105-
.Append(mlContext.Transforms.NormalizeMinMax("Features"))
106-
.Append(
107-
mlContext.Regression.Trainers.OnlineGradientDescent(
108-
labelColumnName: "Label",
109-
featureColumnName: "Features",
110-
numberOfIterations: 1000
111-
)
112-
);
113-
114-
// Perform cross-validation
115-
var cvResults = mlContext.Regression.CrossValidate(
116-
allData,
117-
pipeline,
118-
numberOfFolds: config.NumFolds,
119-
labelColumnName: "Label"
120103
);
121104

122-
// Extract metrics from each fold
123-
var foldMetrics = cvResults
124-
.Select(
125-
(result, index) =>
126-
new FoldMetric
127-
{
128-
FoldNumber = index + 1,
129-
R2Score = result.Metrics.RSquared,
130-
MeanAbsoluteError = result.Metrics.MeanAbsoluteError,
131-
RootMeanSquaredError = result.Metrics.RootMeanSquaredError,
132-
LossFunctionValue = result.Metrics.LossFunction,
133-
}
134-
)
135-
.ToList();
136-
137-
// Calculate statistics
138-
var r2Scores = foldMetrics.Select(f => f.R2Score).ToList();
139-
var meanR2 = r2Scores.Average();
140-
var stdDevR2 = Math.Sqrt(r2Scores.Select(x => Math.Pow(x - meanR2, 2)).Average());
141-
var minR2 = r2Scores.Min();
142-
var maxR2 = r2Scores.Max();
143-
144-
logger?.LogInformation("Cross-validation complete:");
145-
logger?.LogInformation(" Mean R²: {MeanR2:F4} ± {StdDev:F4}", meanR2, stdDevR2);
146-
logger?.LogInformation(" Range: [{Min:F4}, {Max:F4}]", minR2, maxR2);
147-
logger?.LogInformation(" Kedro R²: {KedroR2:F4}", config.KedroReferenceR2Score);
148-
logger?.LogInformation(
149-
" Difference: {Diff:F4} ({Pct:F1}%)",
150-
Math.Abs(meanR2 - config.KedroReferenceR2Score),
151-
Math.Abs(meanR2 - config.KedroReferenceR2Score) / config.KedroReferenceR2Score * 100
152-
);
153-
154-
foreach (var fold in foldMetrics)
155-
{
156-
logger?.LogInformation(
157-
" Fold {Fold}: R²={R2:F4}, MAE={MAE:F2}, RMSE={RMSE:F2}",
158-
fold.FoldNumber,
159-
fold.R2Score,
160-
fold.MeanAbsoluteError,
161-
fold.RootMeanSquaredError
162-
);
163-
}
164-
165-
var results = new CrossValidationResults
166-
{
167-
FoldMetrics = foldMetrics,
168-
MeanR2Score = meanR2,
169-
StdDevR2Score = stdDevR2,
170-
MinR2Score = minR2,
171-
MaxR2Score = maxR2,
172-
NumFolds = config.NumFolds,
173-
KedroR2Score = config.KedroReferenceR2Score,
174-
DifferenceFromKedro = Math.Abs(meanR2 - config.KedroReferenceR2Score),
175-
};
105+
// Perform cross-validation
106+
var cvResults = mlContext.Regression.CrossValidate(
107+
allData,
108+
pipeline,
109+
numberOfFolds: config.NumFolds,
110+
labelColumnName: "Label"
111+
);
112+
113+
// Extract metrics from each fold
114+
var foldMetrics = cvResults
115+
.Select(
116+
(result, index) =>
117+
new FoldMetric
118+
{
119+
FoldNumber = index + 1,
120+
R2Score = result.Metrics.RSquared,
121+
MeanAbsoluteError = result.Metrics.MeanAbsoluteError,
122+
RootMeanSquaredError = result.Metrics.RootMeanSquaredError,
123+
LossFunctionValue = result.Metrics.LossFunction,
124+
}
125+
)
126+
.ToList();
127+
128+
// Calculate statistics
129+
var r2Scores = foldMetrics.Select(f => f.R2Score).ToList();
130+
var meanR2 = r2Scores.Average();
131+
var stdDevR2 = Math.Sqrt(r2Scores.Select(x => Math.Pow(x - meanR2, 2)).Average());
132+
var minR2 = r2Scores.Min();
133+
var maxR2 = r2Scores.Max();
134+
135+
foreach (var fold in foldMetrics)
136+
{
137+
_ = fold; // suppress unused warning if needed
138+
}
176139

177-
return results;
140+
var results = new CrossValidationResults
141+
{
142+
FoldMetrics = foldMetrics,
143+
MeanR2Score = meanR2,
144+
StdDevR2Score = stdDevR2,
145+
MinR2Score = minR2,
146+
MaxR2Score = maxR2,
147+
NumFolds = config.NumFolds,
148+
KedroR2Score = config.KedroReferenceR2Score,
149+
DifferenceFromKedro = Math.Abs(meanR2 - config.KedroReferenceR2Score),
178150
};
151+
152+
return results;
179153
}
180154
}

examples/advanced/KedroSpaceflights.Custom/Flows/DataScience/DataScienceFlow.cs

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -33,26 +33,15 @@ namespace KedroSpaceflights.Custom.Flows.DataScience;
3333
/// </summary>
3434
public static class DataScienceFlow
3535
{
36-
/// <summary>
37-
/// Parameters for the data science pipeline nodes.
38-
/// </summary>
39-
public record Params
40-
{
41-
/// <summary>
42-
/// Options for model training.
43-
/// </summary>
44-
public CreateTestTrainSplitStep.TestTrainSplitParams ModelParams { get; init; } = new();
45-
}
46-
47-
public static Flow Create(Catalog catalog, Params parameters)
36+
public static Flow Create(Catalog catalog, FlowConfig config)
4837
{
4938
return FlowBuilder.CreateFlow(pipeline =>
5039
{
5140
// Step 1: Split data into train/test sets (single input → multi-output)
5241
pipeline.AddStep(
5342
label: "CreateTestTrainSplitDatasets",
54-
transform: CreateTestTrainSplitStep.Create(parameters: parameters.ModelParams),
55-
input: catalog.ModelInputTable,
43+
transform: CreateTestTrainSplitStep.Create,
44+
input: (catalog.ModelInputTable, config.ModelParams),
5645
output: (catalog.XTrain, catalog.XTest, catalog.YTrain, catalog.YTest)
5746
);
5847

0 commit comments

Comments
 (0)