Skip to content

Commit c7b584d

Browse files
authored
feat: add configurable missing data strategies and expanded tests (#10)
- Introduce a `missing_strategy` parameter for the Rasch, 2PL, and 3PL models (`:ignore`, `:treat_as_incorrect`, `:treat_as_correct`).
- Add logic to handle nil responses according to the chosen strategy, or skip them when ignoring.
- Expand RSpec tests to cover repeated fitting, deterministic seeds, large random datasets, and missing-data strategies.
- Minor code cleanup and improved documentation around usage.
1 parent 5b02354 commit c7b584d

8 files changed

Lines changed: 348 additions & 77 deletions

.rubocop.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ Metrics/CyclomaticComplexity:
3535

3636
Metrics/PerceivedComplexity:
3737
Enabled: false
38+
39+
Style/HashLikeCase:
40+
Enabled: false

lib/irt_ruby.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# frozen_string_literal: true
22

33
require "irt_ruby/version"
4+
require "matrix"
45
require "irt_ruby/rasch_model"
56
require "irt_ruby/two_parameter_model"
67
require "irt_ruby/three_parameter_model"

lib/irt_ruby/rasch_model.rb

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,70 @@
11
# frozen_string_literal: true
22

3-
require "matrix"
4-
53
module IrtRuby
64
# A class representing the Rasch model for Item Response Theory (ability - difficulty).
75
# Incorporates:
86
# - Adaptive learning rate
97
# - Missing data handling (skip nil)
108
# - Multiple convergence checks (log-likelihood + parameter updates)
119
class RaschModel
12-
def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
13-
learning_rate: 0.01, decay_factor: 0.5)
10+
MISSING_STRATEGIES = %i[ignore treat_as_incorrect treat_as_correct].freeze
11+
12+
def initialize(data,
13+
max_iter: 1000,
14+
tolerance: 1e-6,
15+
param_tolerance: 1e-6,
16+
learning_rate: 0.01,
17+
decay_factor: 0.5,
18+
missing_strategy: :ignore)
1419
# data: A Matrix or array-of-arrays of responses (0/1 or nil for missing).
15-
# Rows = respondents, Columns = items.
20+
# missing_strategy: :ignore (skip), :treat_as_incorrect, :treat_as_correct
1621

1722
@data = data
1823
@data_array = data.to_a
1924
num_rows = @data_array.size
2025
num_cols = @data_array.first.size
2126

27+
raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy)
28+
29+
@missing_strategy = missing_strategy
30+
2231
# Initialize parameters near zero
2332
@abilities = Array.new(num_rows) { rand(-0.25..0.25) }
2433
@difficulties = Array.new(num_cols) { rand(-0.25..0.25) }
2534

26-
@max_iter = max_iter
27-
@tolerance = tolerance
35+
@max_iter = max_iter
36+
@tolerance = tolerance
2837
@param_tolerance = param_tolerance
29-
@learning_rate = learning_rate
30-
@decay_factor = decay_factor
38+
@learning_rate = learning_rate
39+
@decay_factor = decay_factor
3140
end
3241

3342
def sigmoid(x)
3443
1.0 / (1.0 + Math.exp(-x))
3544
end
3645

46+
def resolve_missing(resp)
47+
return [resp, false] unless resp.nil?
48+
49+
case @missing_strategy
50+
when :ignore
51+
[nil, true]
52+
when :treat_as_incorrect
53+
[0, false]
54+
when :treat_as_correct
55+
[1, false]
56+
end
57+
end
58+
3759
def log_likelihood
3860
total_ll = 0.0
3961
@data_array.each_with_index do |row, i|
4062
row.each_with_index do |resp, j|
41-
next if resp.nil?
63+
value, skip = resolve_missing(resp)
64+
next if skip
4265

4366
prob = sigmoid(@abilities[i] - @difficulties[j])
44-
total_ll += if resp == 1
67+
total_ll += if value == 1
4568
Math.log(prob + 1e-15)
4669
else
4770
Math.log((1 - prob) + 1e-15)
@@ -57,10 +80,11 @@ def compute_gradient
5780

5881
@data_array.each_with_index do |row, i|
5982
row.each_with_index do |resp, j|
60-
next if resp.nil?
83+
value, skip = resolve_missing(resp)
84+
next if skip
6185

6286
prob = sigmoid(@abilities[i] - @difficulties[j])
63-
error = resp - prob
87+
error = value - prob
6488

6589
grad_abilities[i] += error
6690
grad_difficulties[j] -= error
@@ -102,18 +126,17 @@ def fit
102126
@max_iter.times do
103127
grad_abilities, grad_difficulties = compute_gradient
104128

105-
old_abilities, old_difficulties = apply_gradient_update(grad_abilities, grad_difficulties)
129+
old_a, old_d = apply_gradient_update(grad_abilities, grad_difficulties)
106130

107-
current_ll = log_likelihood
108-
param_delta = average_param_update(old_abilities, old_difficulties)
131+
current_ll = log_likelihood
132+
param_delta = average_param_update(old_a, old_d)
109133

110134
if current_ll < prev_ll
111-
@abilities = old_abilities
112-
@difficulties = old_difficulties
135+
@abilities = old_a
136+
@difficulties = old_d
113137
@learning_rate *= @decay_factor
114138
else
115139
ll_diff = (current_ll - prev_ll).abs
116-
117140
break if ll_diff < @tolerance && param_delta < @param_tolerance
118141

119142
prev_ll = current_ll

lib/irt_ruby/three_parameter_model.rb

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# frozen_string_literal: true
22

3-
require "matrix"
4-
53
module IrtRuby
64
# A class representing the Three-Parameter model (3PL) for Item Response Theory.
75
# Incorporates:
@@ -11,14 +9,25 @@ module IrtRuby
119
# - Multiple convergence checks
1210
# - Separate gradient calculation & updates
1311
class ThreeParameterModel
14-
def initialize(data, max_iter: 1000, tolerance: 1e-6, param_tolerance: 1e-6,
15-
learning_rate: 0.01, decay_factor: 0.5)
12+
MISSING_STRATEGIES = %i[ignore treat_as_incorrect treat_as_correct].freeze
13+
14+
def initialize(data,
15+
max_iter: 1000,
16+
tolerance: 1e-6,
17+
param_tolerance: 1e-6,
18+
learning_rate: 0.01,
19+
decay_factor: 0.5,
20+
missing_strategy: :ignore)
1621
@data = data
1722
@data_array = data.to_a
1823
num_rows = @data_array.size
1924
num_cols = @data_array.first.size
2025

21-
# Typical initialization for 3PL
26+
raise ArgumentError, "missing_strategy must be one of #{MISSING_STRATEGIES}" unless MISSING_STRATEGIES.include?(missing_strategy)
27+
28+
@missing_strategy = missing_strategy
29+
30+
# Initialize parameters
2231
@abilities = Array.new(num_rows) { rand(-0.25..0.25) }
2332
@difficulties = Array.new(num_cols) { rand(-0.25..0.25) }
2433
@discriminations = Array.new(num_cols) { rand(0.5..1.5) }
@@ -40,15 +49,32 @@ def probability(theta, a, b, c)
4049
c + (1.0 - c) * sigmoid(a * (theta - b))
4150
end
4251

52+
def resolve_missing(resp)
53+
return [resp, false] unless resp.nil?
54+
55+
case @missing_strategy
56+
when :ignore
57+
[nil, true]
58+
when :treat_as_incorrect
59+
[0, false]
60+
when :treat_as_correct
61+
[1, false]
62+
end
63+
end
64+
4365
def log_likelihood
4466
ll = 0.0
4567
@data_array.each_with_index do |row, i|
4668
row.each_with_index do |resp, j|
47-
next if resp.nil?
69+
value, skip = resolve_missing(resp)
70+
next if skip
4871

49-
prob = probability(@abilities[i], @discriminations[j],
50-
@difficulties[j], @guessings[j])
51-
ll += if resp == 1
72+
prob = probability(@abilities[i],
73+
@discriminations[j],
74+
@difficulties[j],
75+
@guessings[j])
76+
77+
ll += if value == 1
5278
Math.log(prob + 1e-15)
5379
else
5480
Math.log((1 - prob) + 1e-15)
@@ -66,32 +92,33 @@ def compute_gradient
6692

6793
@data_array.each_with_index do |row, i|
6894
row.each_with_index do |resp, j|
69-
next if resp.nil?
95+
value, skip = resolve_missing(resp)
96+
next if skip
7097

7198
theta = @abilities[i]
7299
a = @discriminations[j]
73100
b = @difficulties[j]
74101
c = @guessings[j]
75102

76103
prob = probability(theta, a, b, c)
77-
error = resp - prob
104+
error = value - prob
78105

79-
grad_abilities[i] += error * a * (1 - c)
80-
grad_difficulties[j] -= error * a * (1 - c)
106+
grad_abilities[i] += error * a * (1 - c)
107+
grad_difficulties[j] -= error * a * (1 - c)
81108
grad_discriminations[j] += error * (theta - b) * (1 - c)
82109

83-
grad_guessings[j] += error * 1.0
110+
grad_guessings[j] += error * 1.0
84111
end
85112
end
86113

87114
[grad_abilities, grad_difficulties, grad_discriminations, grad_guessings]
88115
end
89116

90117
def apply_gradient_update(ga, gd, gdisc, gc)
91-
old_abilities = @abilities.dup
92-
old_difficulties = @difficulties.dup
93-
old_discriminations = @discriminations.dup
94-
old_guessings = @guessings.dup
118+
old_a = @abilities.dup
119+
old_d = @difficulties.dup
120+
old_disc = @discriminations.dup
121+
old_c = @guessings.dup
95122

96123
@abilities.each_index do |i|
97124
@abilities[i] += @learning_rate * ga[i]
@@ -113,23 +140,15 @@ def apply_gradient_update(ga, gd, gdisc, gc)
113140
@guessings[j] = 0.35 if @guessings[j] > 0.35
114141
end
115142

116-
[old_abilities, old_difficulties, old_discriminations, old_guessings]
143+
[old_a, old_d, old_disc, old_c]
117144
end
118145

119146
def average_param_update(old_a, old_d, old_disc, old_c)
120147
deltas = []
121-
@abilities.each_with_index do |x, i|
122-
deltas << (x - old_a[i]).abs
123-
end
124-
@difficulties.each_with_index do |x, j|
125-
deltas << (x - old_d[j]).abs
126-
end
127-
@discriminations.each_with_index do |x, j|
128-
deltas << (x - old_disc[j]).abs
129-
end
130-
@guessings.each_with_index do |x, j|
131-
deltas << (x - old_c[j]).abs
132-
end
148+
@abilities.each_with_index { |x, i| deltas << (x - old_a[i]).abs }
149+
@difficulties.each_with_index { |x, j| deltas << (x - old_d[j]).abs }
150+
@discriminations.each_with_index { |x, j| deltas << (x - old_disc[j]).abs }
151+
@guessings.each_with_index { |x, j| deltas << (x - old_c[j]).abs }
133152
deltas.sum / deltas.size
134153
end
135154

@@ -140,15 +159,14 @@ def fit
140159
ga, gd, gdisc, gc = compute_gradient
141160
old_a, old_d, old_disc, old_c = apply_gradient_update(ga, gd, gdisc, gc)
142161

143-
curr_ll = log_likelihood
162+
curr_ll = log_likelihood
144163
param_delta = average_param_update(old_a, old_d, old_disc, old_c)
145164

146165
if curr_ll < prev_ll
147166
@abilities = old_a
148167
@difficulties = old_d
149168
@discriminations = old_disc
150169
@guessings = old_c
151-
152170
@learning_rate *= @decay_factor
153171
else
154172
ll_diff = (curr_ll - prev_ll).abs

0 commit comments

Comments
 (0)