-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdtype_test.go
150 lines (134 loc) · 3.32 KB
/
dtype_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package dtype
import (
"reflect"
"testing"
)
// Float16 is a test-only type backed by uint16. The tests in this file wrap
// it in a Dtype to exercise RegisterFloat and to serve as a Dtype for which
// NumpyDtype is expected to return an error.
type Float16 uint16
// TestRegisterType registers a Dtype wrapping Float16 through RegisterFloat
// and verifies that typeclassCheck accepts it for the float, number, ord and
// eq type classes, in that order.
func TestRegisterType(t *testing.T) {
	f16 := Dtype{reflect.TypeOf(Float16(0))}
	RegisterFloat(f16)

	// Each entry pairs a type class with the failure message reported when
	// the freshly registered Dtype is not accepted by that class.
	checks := []struct {
		class  *typeclass
		format string
	}{
		{floatTypes, "Expected %v to be in floatTypes: %v"},
		{numberTypes, "Expected %v to be in numberTypes: %v"},
		{ordTypes, "Expected %v to be in ordTypes: %v"},
		{eqTypes, "Expected %v to be in eqTypes: %v"},
	}
	for _, c := range checks {
		err := typeclassCheck(f16, c.class)
		if err == nil {
			continue
		}
		t.Errorf(c.format, f16, err)
	}
}
// TestDtypeConversions exercises the conversions between Dtype values and
// numpy dtype strings in both directions, plus the failure path of each
// conversion.
func TestDtypeConversions(t *testing.T) {
	// Dtype -> numpy string: each value in reverseNumpyDtypes must report its
	// own map key as its numpy dtype.
	for k, v := range reverseNumpyDtypes {
		npdt, err := v.NumpyDtype()
		if err != nil {
			// Check the error before the value: the returned string carries no
			// meaning when err is set, and comparing it first would hide the
			// actual failure behind a mismatch message.
			t.Errorf("Error getting numpy dtype of %v: %v", v, err)
			continue
		}
		if npdt != k {
			t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt)
		}
	}

	// A Dtype wrapping the test-only Float16 type is expected to be rejected.
	dt := Dtype{reflect.TypeOf(Float16(0))}
	if _, err := dt.NumpyDtype(); err == nil {
		t.Errorf("Expected an error when passing in type unknown to np")
	}

	// numpy string -> Dtype: each value in numpyDtypes must convert back to
	// its own map key, modulo the Int/Uint special cases below.
	for k, v := range numpyDtypes {
		dt, err := FromNumpyDtype(v)
		if err != nil {
			// As above, surface the error instead of comparing a meaningless
			// result (or silently skipping it through a special case).
			t.Errorf("Error converting numpy dtype %q: %v", v, err)
			continue
		}
		if dt == k {
			continue
		}

		// special cases: Int and Uint are tolerated as the result for the
		// fixed-width code whose byte width equals their own Size()
		// (presumably because their size depends on the platform).
		if Int.Size() == 4 && v == "i4" && dt == Int {
			continue
		}
		if Int.Size() == 8 && v == "i8" && dt == Int {
			continue
		}
		if Uint.Size() == 4 && v == "u4" && dt == Uint {
			continue
		}
		if Uint.Size() == 8 && v == "u8" && dt == Uint {
			continue
		}
		t.Errorf("Expected %q to return %v. Got %v instead", v, k, dt)
	}

	// An unrecognised numpy dtype string must be rejected.
	if _, err := FromNumpyDtype("EDIUH"); err == nil {
		t.Error("Expected error when nonsense is passed into fromNumpyDtype")
	}
}
// TestAllTypes walks the members of every listed type class and reports any
// Dtype for which ID returns -1.
func TestAllTypes(t *testing.T) {
	classes := []*typeclass{
		specializedTypes,
		addableTypes,
		numberTypes,
		ordTypes,
		eqTypes,
		signedTypes,
		unsignedTypes,
		signedNonComplexTypes,
		floatTypes,
		complexTypes,
		floatcmplxTypes,
		nonComplexNumberTypes,
		generatableTypes,
	}

	for _, class := range classes {
		// The set is read while holding the type class's own lock.
		class.Lock()
		for _, member := range class.set {
			if ID(member) != -1 {
				continue
			}
			t.Errorf("Dtype %v has no ID in allTypes. It is not properly registered", member)
		}
		class.Unlock()
	}
}
// TestFindByName checks that FindByName resolves "float64" to Float64 and
// returns an error for a name that should not exist.
func TestFindByName(t *testing.T) {
	// Positive lookup: both the error and the returned Dtype are checked
	// independently, so a failed lookup may report twice.
	found, err := FindByName("float64")
	if err != nil {
		t.Errorf("Expected \"float64\" to be found")
	}
	if found != Float64 {
		t.Errorf("Got a different dtype than expected")
	}

	// Negative lookup: a nonsense name must produce an error.
	if _, err = FindByName("f00b4rb4z"); err == nil {
		t.Errorf("Expected the Dtype named \"f00b4rb4z\" to not be found ")
	}
}
// TestRegisterFloat registers the same Float16-backed Dtype more than once
// and checks that floatTypes.set contains it exactly one time.
func TestRegisterFloat(t *testing.T) {
	f16 := Dtype{reflect.TypeOf(Float16(0))}

	// Register twice on purpose: the repeat is the behaviour under test.
	for i := 0; i < 2; i++ {
		RegisterFloat(f16)
	}

	// Count occurrences while holding the type class's lock.
	occurrences := 0
	floatTypes.Lock()
	for _, registered := range floatTypes.set {
		if registered != f16 {
			continue
		}
		occurrences++
	}
	floatTypes.Unlock()

	if occurrences == 1 {
		return
	}
	t.Errorf("Expected Float16 to only exist once in the float types set")
}
// TestTypeClassCheck runs TypeClassCheck for Float64 against a table of type
// classes and compares whether an error was returned with the expectation
// recorded for each entry.
func TestTypeClassCheck(t *testing.T) {
	subject := Float64

	table := []struct {
		class   TypeClass
		wantErr bool
	}{
		{class: Number, wantErr: false},
		{class: maxtypeclass, wantErr: true},
		{class: Unsigned, wantErr: true},
		{class: -1, wantErr: false},
	}

	for _, entry := range table {
		gotErr := TypeClassCheck(subject, entry.class) != nil
		if gotErr == entry.wantErr {
			continue
		}
		// Outcome and expectation disagree; pick the message by what was expected.
		if entry.wantErr {
			t.Errorf("Expected Float64 in %v to error.", entry.class)
		} else {
			t.Errorf("Expected Float64 in %v to not error", entry.class)
		}
	}
}