Skip to content

Commit c582f71

Browse files
committed
introduce MLStyle.enum_matcher
1 parent 0b12b7a commit c582f71

5 files changed

Lines changed: 163 additions & 14 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLStyle"
22
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
33
authors = ["thautwarm <twshere@outlook.com>"]
4-
version = "0.4.16"
4+
version = "0.4.17"
55

66
[deps]
77

docs/syntax/pattern.md

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,24 +350,28 @@ You can extend following APIs for your pattern objects, to implement custom patt
350350
`MLStyle.pattern_unref(pat_obj, expr_to_pat, [:a, :b]`.
351351

352352
- `MLStyle.is_enum`
353-
353+
354354
In a pattern `[A, B]`, usually we think both `A` and `B` are capturing patterns. However, it is handy if we can have a pattern `A` whose match means comparing to the global variable `A`.
355355

356356
To achieve this, we provide `MLStyle.is_enum`.
357357
For a visible global variable `A`, if `MLStyle.is_enum(A) == true`, a symbol `A` will compile into a pattern with `MLStyle.pattern_uncall(A, expr_to_ast, [], [], [])`.
358358

359+
- `MLStyle.enum_matcher(E, value_to_match)`:
360+
361+
If `MLStyle.is_enum(E) == true`, we will call `MLStyle.enum_matcher(E, value_to_match)` to compile `E` into a pattern.
362+
359363
We present some examples for understandability:
360364

361365
### Support Pattern Matching for Julia Enums
362366

363367
```julia-console
364368
julia> using MLStyle
365-
julia> using MLStyle.AbstractPatterns: literal
366369
julia> @enum E E1 E2
367370
# mark E1, E2 as non-capturing patterns
368371
julia> MLStyle.is_enum(::E) = true
369-
# tell the compiler how to match E1, E2
370-
julia> MLStyle.pattern_uncall(e::E, _, _, _, _) = literal(e)
372+
# tell the compiler how to match E1 and E2
373+
# NOTE: make sure it evaluates to a boolean value!
374+
julia> MLStyle.enum_matcher(e::E, expr) = :($e === $expr)
371375
julia> x = E2
372376
julia> @match x begin
373377
E1 => "match E1!"
@@ -383,13 +387,13 @@ julia> @macroexpand @match x begin
383387
:(let
384388
var"##return#261" = nothing
385389
var"##263" = x
386-
if var"##263" === E1
390+
if E1 === var"##263"
387391
var"##return#261" = let
388392
"match E1!"
389393
end
390394
$(Expr(:symbolicgoto, Symbol("####final#262#264")))
391395
end
392-
if var"##263" === E2
396+
if E1 === var"##263"
393397
var"##return#261" = let
394398
"match E2!"
395399
end

src/MatchImpl.jl

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,19 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@compi
55
end
66

77
export is_enum,
8-
pattern_uncall, pattern_unref, pattern_unmacrocall, @switch, @case, @tryswitch, @match, @trymatch, Where, gen_match, gen_switch
8+
enum_matcher,
9+
pattern_uncall,
10+
pattern_unref,
11+
pattern_unmacrocall,
12+
@switch,
13+
@case,
14+
@tryswitch,
15+
@match,
16+
@trymatch,
17+
Where,
18+
gen_match,
19+
gen_switch
20+
921
export Q
1022
import MLStyle
1123
using MLStyle: mlstyle_report_deprecation_msg!
@@ -17,8 +29,70 @@ using MLStyle.AbstractPatterns
1729
using MLStyle.AbstractPatterns.BasicPatterns
1830
OptionalLn = Union{LineNumberNode, Nothing}
1931

32+
"""
33+
is_enum(EnumPattern)::Bool
34+
35+
Convert the pattern `EnumPattern` to `EnumPattern()`.
36+
37+
e.g.,
38+
```
39+
abstract type AbsS end
40+
struct S1 <: AbsS end
41+
struct S2 <: AbsS end
42+
MLStyle.pattern_uncall(::Type{S}, self, _, _, _) where {S<:AbsS} = literal(S())
43+
MLStyle.is_enum(::Type{<:AbsS}) = true
44+
45+
x = S1()
46+
@match x begin
47+
S2 => 1
48+
S1 => 2
49+
end
50+
```
51+
52+
"""
2053
is_enum(_)::Bool = false
21-
function pattern_uncall end
54+
55+
"""
56+
enum_matcher(Enum, value)::Expr
57+
58+
Generates the expression used to test if `value` is the case `Enum`.
59+
60+
NOTE that this only works when `is_enum(Enum)` is `true`!!!
61+
62+
@match V begin
63+
Enum => ...
64+
@end
65+
66+
Above single case matches when
67+
68+
1. `enum_matcher(Enum, ::Any)` is not defined and `V == Enum`.
69+
2. The expression generated from `enum_matcher(Enum, :V)`
70+
evaluates to `true` under the current module.
71+
72+
"""
73+
function enum_matcher end
74+
75+
struct _EnumCase{E}
76+
pattern::E
77+
end
78+
79+
function pattern_uncall(enumCase::_EnumCase{E}, self, type_params, type_args, args) where E
80+
isempty(type_params) || error("Enum type should not have type parameters!")
81+
isempty(type_args) || error("Enum type should not have type arguments!")
82+
isempty(args) || error("Enum type should not have arguments!")
83+
84+
let enumPattern = enumCase.pattern
85+
if hasmethod(MLStyle.enum_matcher, Tuple{E, Any})
86+
function via_enum_matcher(target, _, _)
87+
return MLStyle.enum_matcher(enumPattern, target)
88+
end
89+
guard(via_enum_matcher)
90+
else
91+
pattern_uncall(enumPattern, self, type_params, type_args, args)
92+
end
93+
end
94+
end
95+
2296
function pattern_unref end
2397
function pattern_unmacrocall(macro_func, self::Function, args::AbstractArray)
2498
@sswitch args begin
@@ -85,6 +159,12 @@ function guess_type_from_expr(m::Module, ex::Any, tps::Set{Symbol})
85159
end
86160
end
87161

162+
struct ModuleBoundedEx2tf <: Function
163+
m::Module
164+
end
165+
166+
@inline (self::ModuleBoundedEx2tf)(arg) = ex2tf(self.m, arg)
167+
88168
ex2tf(m::Module, @nospecialize(a)) = literal(a)
89169
ex2tf(m::Module, l::LineNumberNode) = wildcard
90170
ex2tf(m::Module, q::QuoteNode) = literal(q.value)
@@ -97,8 +177,8 @@ ex2tf(m::Module, n::Symbol) =
97177
else
98178
if isdefined(m, n)
99179
p = getfield(m, n)
100-
rec(x) = ex2tf(m, x)
101-
is_enum(p) && return pattern_uncall(p, rec, [], [], [])
180+
rec = ModuleBoundedEx2tf(m)
181+
is_enum(p) && return pattern_uncall(_EnumCase(p), rec, [], [], [])
102182
end
103183
P_capture(n)
104184
end
@@ -112,7 +192,8 @@ function ex2tf(m::Module, s::QuotePattern)
112192
end
113193

114194
function ex2tf(m::Module, w::Where)
115-
rec(x) = ex2tf(m, x)
195+
rec = ModuleBoundedEx2tf(m)
196+
116197
@sswitch w begin
117198
@case Where(; value = val, type = t, type_parameters = tps)
118199

@@ -170,7 +251,7 @@ end
170251

171252
function ex2tf(m::Module, ex::Expr)
172253
eval = m.eval
173-
rec(x) = ex2tf(m, x)
254+
rec = ModuleBoundedEx2tf(m)
174255

175256
@sswitch ex begin
176257
@case Expr(:||, args)

test/issues/154.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using MLStyle
2+
import MLStyle.AbstractPatterns
3+
4+
abstract type Enum154 end
5+
6+
struct Enum154_1_Cons <: Enum154 end
7+
8+
struct Enum154_2_Cons <: Enum154
9+
x::Vector{Int}
10+
end
11+
MLStyle.@as_record Enum154_2_Cons
12+
13+
MLStyle.is_enum(::Enum154) = true
14+
MLStyle.enum_matcher(enum::Enum154, expr) = :($enum === $expr)
15+
16+
const Enum154_1 = Enum154_1_Cons()
17+
18+
function Base.:(==)(a::Enum154, b::Enum154)
19+
@match (a, b) begin
20+
(Enum154_1, Enum154_1) => true
21+
(Enum154_2_Cons(xs), Enum154_2_Cons(ys)) => xs == ys
22+
_ => false
23+
end
24+
end
25+
26+
# traditional behaviour
27+
28+
@enum JuliaEnum_154 begin
29+
JuliaEnum_154_a
30+
JuliaEnum_154_b
31+
JuliaEnum_154_c
32+
end
33+
34+
MLStyle.is_enum(::JuliaEnum_154) = true
35+
36+
MLStyle.pattern_uncall(a::JuliaEnum_154, ::Vararg) = MLStyle.AbstractPatterns.literal(a)
37+
38+
function eq_154(a, b)
39+
@match (a, b) begin
40+
(JuliaEnum_154_a, JuliaEnum_154_a) => true
41+
(JuliaEnum_154_b, JuliaEnum_154_b) => true
42+
(JuliaEnum_154_c, JuliaEnum_154_c) => true
43+
_ => false
44+
end
45+
end
46+
47+
@testset "issue 154" begin
48+
@testset "tag matching support" begin
49+
@test Enum154_1 == Enum154_1
50+
@test Enum154_2_Cons([1, 2, 3]) == Enum154_2_Cons([1, 2, 3])
51+
@test Enum154_2_Cons([1, 2, 3]) != Enum154_2_Cons([1, 2, 4])
52+
@test Enum154_1 != Enum154_2_Cons([1, 2, 3])
53+
end
54+
55+
@testset "traditional" begin
56+
@test eq_154(JuliaEnum_154_a, JuliaEnum_154_a)
57+
@test eq_154(JuliaEnum_154_b, JuliaEnum_154_b)
58+
@test eq_154(JuliaEnum_154_c, JuliaEnum_154_c)
59+
@test !eq_154(JuliaEnum_154_a, JuliaEnum_154_b)
60+
@test !eq_154(JuliaEnum_154_b, JuliaEnum_154_c)
61+
@test !eq_154(JuliaEnum_154_c, JuliaEnum_154_a)
62+
end
63+
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ MODULE = TestModule
5252

5353
@use GADT
5454

55-
include("issues/109.jl")
5655
include("when.jl")
5756
include("switch.jl")
5857
include("untyped_lam.jl")
@@ -80,5 +79,7 @@ include("MQuery/test.jl")
8079

8180
include("issues/87.jl")
8281
include("issues/62.jl")
82+
include("issues/109.jl")
83+
include("issues/154.jl")
8384

8485
end

0 commit comments

Comments
 (0)