|
2 | 2 | using Flowthru.Core.Steps; |
3 | 3 | using KedroSpaceflights.Custom.Data._03_Primary.Schemas; |
4 | 4 | using KedroSpaceflights.Custom.Data._06_Reporting.Schemas; |
5 | | -using Microsoft.Extensions.Logging; |
6 | 5 | using Microsoft.ML; |
7 | 6 |
|
8 | 7 | namespace KedroSpaceflights.Custom.Flows.DataEvaluation.Steps; |
@@ -43,138 +42,113 @@ public record Params |
43 | 42 | public float KedroReferenceR2Score { get; init; } |
44 | 43 | } |
45 | 44 |
|
46 | | - public static Func<IEnumerable<ModelInputSchema>, Task<CrossValidationResults>> Create( |
47 | | - Params? parameters = null, |
48 | | - ILogger? logger = null |
| 45 | + public static async Task<CrossValidationResults> Create( |
| 46 | + (IEnumerable<ModelInputSchema> Data, Params Options) input |
49 | 47 | ) |
50 | 48 | { |
51 | | - var config = parameters ?? new Params(); |
| 49 | + var (rawInput, config) = input; |
| 50 | + var data = rawInput.ToList(); |
52 | 51 |
|
53 | | - return async (input) => |
54 | | - { |
55 | | - var data = input.ToList(); |
56 | | - logger?.LogInformation("Starting cross-validation with {Folds} folds", config.NumFolds); |
57 | | - |
58 | | - // Convert to feature rows |
59 | | - var featureRows = data.Select(row => new FeatureRow |
60 | | - { |
61 | | - Engines = (float)row.Engines, |
62 | | - PassengerCapacity = (float)row.PassengerCapacity, |
63 | | - Crew = (float)row.Crew, |
64 | | - DCheckComplete = row.DCheckComplete, |
65 | | - IataApproved = row.IataApproved, |
66 | | - CompanyRating = (float)row.CompanyRating, |
67 | | - ReviewScoresRating = (float)row.ReviewScoresRating, |
68 | | - Price = (float)row.Price, |
69 | | - }) |
70 | | - .ToList(); |
71 | | - |
72 | | - var mlContext = new MLContext(seed: config.BaseSeed); |
73 | | - var allData = mlContext.Data.LoadFromEnumerable(featureRows); |
74 | | - |
75 | | - // Define the ML pipeline (same as TrainModelStep) |
76 | | - var pipeline = mlContext |
77 | | - .Transforms.CopyColumns( |
78 | | - outputColumnName: "Label", |
79 | | - inputColumnName: nameof(FeatureRow.Price) |
| 52 | + // Convert to feature rows |
| 53 | + var featureRows = data.Select(row => new FeatureRow |
| 54 | + { |
| 55 | + Engines = (float)row.Engines, |
| 56 | + PassengerCapacity = (float)row.PassengerCapacity, |
| 57 | + Crew = (float)row.Crew, |
| 58 | + DCheckComplete = row.DCheckComplete, |
| 59 | + IataApproved = row.IataApproved, |
| 60 | + CompanyRating = (float)row.CompanyRating, |
| 61 | + ReviewScoresRating = (float)row.ReviewScoresRating, |
| 62 | + Price = (float)row.Price, |
| 63 | + }) |
| 64 | + .ToList(); |
| 65 | + |
| 66 | + var mlContext = new MLContext(seed: config.BaseSeed); |
| 67 | + var allData = mlContext.Data.LoadFromEnumerable(featureRows); |
| 68 | + |
| 69 | + // Define the ML pipeline (same as TrainModelStep) |
| 70 | + var pipeline = mlContext |
| 71 | + .Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: nameof(FeatureRow.Price)) |
| 72 | + .Append( |
| 73 | + mlContext.Transforms.Categorical.OneHotEncoding( |
| 74 | + outputColumnName: "DCheckCompleteEncoded", |
| 75 | + inputColumnName: nameof(FeatureRow.DCheckComplete) |
80 | 76 | ) |
81 | | - .Append( |
82 | | - mlContext.Transforms.Categorical.OneHotEncoding( |
83 | | - outputColumnName: "DCheckCompleteEncoded", |
84 | | - inputColumnName: nameof(FeatureRow.DCheckComplete) |
85 | | - ) |
| 77 | + ) |
| 78 | + .Append( |
| 79 | + mlContext.Transforms.Categorical.OneHotEncoding( |
| 80 | + outputColumnName: "IataApprovedEncoded", |
| 81 | + inputColumnName: nameof(FeatureRow.IataApproved) |
86 | 82 | ) |
87 | | - .Append( |
88 | | - mlContext.Transforms.Categorical.OneHotEncoding( |
89 | | - outputColumnName: "IataApprovedEncoded", |
90 | | - inputColumnName: nameof(FeatureRow.IataApproved) |
91 | | - ) |
| 83 | + ) |
| 84 | + .Append( |
| 85 | + mlContext.Transforms.Concatenate( |
| 86 | + "Features", |
| 87 | + nameof(FeatureRow.Engines), |
| 88 | + nameof(FeatureRow.PassengerCapacity), |
| 89 | + nameof(FeatureRow.Crew), |
| 90 | + "DCheckCompleteEncoded", |
| 91 | + "IataApprovedEncoded", |
| 92 | + nameof(FeatureRow.CompanyRating), |
| 93 | + nameof(FeatureRow.ReviewScoresRating) |
92 | 94 | ) |
93 | | - .Append( |
94 | | - mlContext.Transforms.Concatenate( |
95 | | - "Features", |
96 | | - nameof(FeatureRow.Engines), |
97 | | - nameof(FeatureRow.PassengerCapacity), |
98 | | - nameof(FeatureRow.Crew), |
99 | | - "DCheckCompleteEncoded", |
100 | | - "IataApprovedEncoded", |
101 | | - nameof(FeatureRow.CompanyRating), |
102 | | - nameof(FeatureRow.ReviewScoresRating) |
103 | | - ) |
| 95 | + ) |
| 96 | + .Append(mlContext.Transforms.NormalizeMinMax("Features")) |
| 97 | + .Append( |
| 98 | + mlContext.Regression.Trainers.OnlineGradientDescent( |
| 99 | + labelColumnName: "Label", |
| 100 | + featureColumnName: "Features", |
| 101 | + numberOfIterations: 1000 |
104 | 102 | ) |
105 | | - .Append(mlContext.Transforms.NormalizeMinMax("Features")) |
106 | | - .Append( |
107 | | - mlContext.Regression.Trainers.OnlineGradientDescent( |
108 | | - labelColumnName: "Label", |
109 | | - featureColumnName: "Features", |
110 | | - numberOfIterations: 1000 |
111 | | - ) |
112 | | - ); |
113 | | - |
114 | | - // Perform cross-validation |
115 | | - var cvResults = mlContext.Regression.CrossValidate( |
116 | | - allData, |
117 | | - pipeline, |
118 | | - numberOfFolds: config.NumFolds, |
119 | | - labelColumnName: "Label" |
120 | 103 | ); |
121 | 104 |
|
122 | | - // Extract metrics from each fold |
123 | | - var foldMetrics = cvResults |
124 | | - .Select( |
125 | | - (result, index) => |
126 | | - new FoldMetric |
127 | | - { |
128 | | - FoldNumber = index + 1, |
129 | | - R2Score = result.Metrics.RSquared, |
130 | | - MeanAbsoluteError = result.Metrics.MeanAbsoluteError, |
131 | | - RootMeanSquaredError = result.Metrics.RootMeanSquaredError, |
132 | | - LossFunctionValue = result.Metrics.LossFunction, |
133 | | - } |
134 | | - ) |
135 | | - .ToList(); |
136 | | - |
137 | | - // Calculate statistics |
138 | | - var r2Scores = foldMetrics.Select(f => f.R2Score).ToList(); |
139 | | - var meanR2 = r2Scores.Average(); |
140 | | - var stdDevR2 = Math.Sqrt(r2Scores.Select(x => Math.Pow(x - meanR2, 2)).Average()); |
141 | | - var minR2 = r2Scores.Min(); |
142 | | - var maxR2 = r2Scores.Max(); |
143 | | - |
144 | | - logger?.LogInformation("Cross-validation complete:"); |
145 | | - logger?.LogInformation(" Mean R²: {MeanR2:F4} ± {StdDev:F4}", meanR2, stdDevR2); |
146 | | - logger?.LogInformation(" Range: [{Min:F4}, {Max:F4}]", minR2, maxR2); |
147 | | - logger?.LogInformation(" Kedro R²: {KedroR2:F4}", config.KedroReferenceR2Score); |
148 | | - logger?.LogInformation( |
149 | | - " Difference: {Diff:F4} ({Pct:F1}%)", |
150 | | - Math.Abs(meanR2 - config.KedroReferenceR2Score), |
151 | | - Math.Abs(meanR2 - config.KedroReferenceR2Score) / config.KedroReferenceR2Score * 100 |
152 | | - ); |
153 | | - |
154 | | - foreach (var fold in foldMetrics) |
155 | | - { |
156 | | - logger?.LogInformation( |
157 | | - " Fold {Fold}: R²={R2:F4}, MAE={MAE:F2}, RMSE={RMSE:F2}", |
158 | | - fold.FoldNumber, |
159 | | - fold.R2Score, |
160 | | - fold.MeanAbsoluteError, |
161 | | - fold.RootMeanSquaredError |
162 | | - ); |
163 | | - } |
164 | | - |
165 | | - var results = new CrossValidationResults |
166 | | - { |
167 | | - FoldMetrics = foldMetrics, |
168 | | - MeanR2Score = meanR2, |
169 | | - StdDevR2Score = stdDevR2, |
170 | | - MinR2Score = minR2, |
171 | | - MaxR2Score = maxR2, |
172 | | - NumFolds = config.NumFolds, |
173 | | - KedroR2Score = config.KedroReferenceR2Score, |
174 | | - DifferenceFromKedro = Math.Abs(meanR2 - config.KedroReferenceR2Score), |
175 | | - }; |
| 105 | + // Perform cross-validation |
| 106 | + var cvResults = mlContext.Regression.CrossValidate( |
| 107 | + allData, |
| 108 | + pipeline, |
| 109 | + numberOfFolds: config.NumFolds, |
| 110 | + labelColumnName: "Label" |
| 111 | + ); |
| 112 | + |
| 113 | + // Extract metrics from each fold |
| 114 | + var foldMetrics = cvResults |
| 115 | + .Select( |
| 116 | + (result, index) => |
| 117 | + new FoldMetric |
| 118 | + { |
| 119 | + FoldNumber = index + 1, |
| 120 | + R2Score = result.Metrics.RSquared, |
| 121 | + MeanAbsoluteError = result.Metrics.MeanAbsoluteError, |
| 122 | + RootMeanSquaredError = result.Metrics.RootMeanSquaredError, |
| 123 | + LossFunctionValue = result.Metrics.LossFunction, |
| 124 | + } |
| 125 | + ) |
| 126 | + .ToList(); |
| 127 | + |
| 128 | + // Calculate statistics |
| 129 | + var r2Scores = foldMetrics.Select(f => f.R2Score).ToList(); |
| 130 | + var meanR2 = r2Scores.Average(); |
| 131 | + var stdDevR2 = Math.Sqrt(r2Scores.Select(x => Math.Pow(x - meanR2, 2)).Average()); |
| 132 | + var minR2 = r2Scores.Min(); |
| 133 | + var maxR2 = r2Scores.Max(); |
| 134 | + |
| 135 | + foreach (var fold in foldMetrics) |
| 136 | + { |
| 137 | + _ = fold; // suppress unused warning if needed |
| 138 | + } |
176 | 139 |
|
177 | | - return results; |
| 140 | + var results = new CrossValidationResults |
| 141 | + { |
| 142 | + FoldMetrics = foldMetrics, |
| 143 | + MeanR2Score = meanR2, |
| 144 | + StdDevR2Score = stdDevR2, |
| 145 | + MinR2Score = minR2, |
| 146 | + MaxR2Score = maxR2, |
| 147 | + NumFolds = config.NumFolds, |
| 148 | + KedroR2Score = config.KedroReferenceR2Score, |
| 149 | + DifferenceFromKedro = Math.Abs(meanR2 - config.KedroReferenceR2Score), |
178 | 150 | }; |
| 151 | + |
| 152 | + return results; |
179 | 153 | } |
180 | 154 | } |
0 commit comments