|
38 | 38 | (: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
|
39 | 39 | (define (array-axis-reduce arr k f)
|
40 | 40 | (let ([k (check-array-axis 'array-axis-reduce arr k)])
|
41 |
| - (unsafe-array-axis-reduce |
42 |
| - arr k |
43 |
| - (λ: ([dk : Index] [proc : (Index -> A)]) |
44 |
| - (: safe-proc (Integer -> A)) |
45 |
| - (define (safe-proc jk) |
46 |
| - (cond [(or (jk . < . 0) (jk . >= . dk)) |
47 |
| - (raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)] |
48 |
| - [else (proc jk)])) |
49 |
| - (f dk safe-proc))))) |
| 41 | + (array-default-strict |
| 42 | + (unsafe-array-axis-reduce |
| 43 | + arr k |
| 44 | + (λ: ([dk : Index] [proc : (Index -> A)]) |
| 45 | + (: safe-proc (Integer -> A)) |
| 46 | + (define (safe-proc jk) |
| 47 | + (cond [(or (jk . < . 0) (jk . >= . dk)) |
| 48 | + (raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)] |
| 49 | + [else (proc jk)])) |
| 50 | + (f dk safe-proc)))))) |
50 | 51 |
|
51 | 52 | (: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B))))
|
52 | 53 | (define (array-axis-fold/init arr k f init)
|
|
72 | 73 | ((Array A) Integer (A B -> B) B -> (Array B)))))
|
73 | 74 | (define array-axis-fold
|
74 | 75 | (case-lambda
|
75 |
| - [(arr k f) (array-axis-fold/no-init arr k f)] |
76 |
| - [(arr k f init) (array-axis-fold/init arr k f init)])) |
| 76 | + [(arr k f) (array-default-strict (array-axis-fold/no-init arr k f))] |
| 77 | + [(arr k f init) (array-default-strict (array-axis-fold/init arr k f init))])) |
77 | 78 |
|
78 | 79 | ;; ===================================================================================================
|
79 | 80 | ;; Whole-array folds
|
|
93 | 94 | (define array-all-fold
|
94 | 95 | (case-lambda
|
95 | 96 | [(arr f)
|
96 |
| - (array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index]) |
97 |
| - (array-axis-fold arr k f))) |
98 |
| - #())] |
| 97 | + ;; Though `f' is folded over multiple axes, each element of `arr' is referred to only once, so |
| 98 | + ;; turning strictness off can't hurt performance |
| 99 | + (parameterize ([array-strictness #f]) |
| 100 | + (array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index]) |
| 101 | + (array-axis-fold arr k f))) |
| 102 | + #()))] |
99 | 103 | [(arr f init)
|
100 |
| - (array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index]) |
101 |
| - (array-axis-fold arr k f init))) |
102 |
| - #())])) |
| 104 | + ;; See above for why non-strictness is okay |
| 105 | + (parameterize ([array-strictness #f]) |
| 106 | + (array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index]) |
| 107 | + (array-axis-fold arr k f init))) |
| 108 | + #()))])) |
103 | 109 |
|
104 | 110 | ) ; begin-encourage-inline
|
105 | 111 |
|
|
109 | 115 | (: array-axis-count (All (A) ((Array A) Integer (A -> Any) -> (Array Index))))
|
110 | 116 | (define (array-axis-count arr k pred?)
|
111 | 117 | (let ([k (check-array-axis 'array-axis-count arr k)])
|
112 |
| - (unsafe-array-axis-reduce |
113 |
| - arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
114 |
| - (let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0]) |
115 |
| - (if (jk . fx< . dk) |
116 |
| - (cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))] |
117 |
| - [else (loop (fx+ jk 1) acc)]) |
118 |
| - (assert acc index?))))))) |
| 118 | + (array-default-strict |
| 119 | + (unsafe-array-axis-reduce |
| 120 | + arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
| 121 | + (let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0]) |
| 122 | + (if (jk . fx< . dk) |
| 123 | + (cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))] |
| 124 | + [else (loop (fx+ jk 1) acc)]) |
| 125 | + (assert acc index?)))))))) |
119 | 126 |
|
120 | 127 | ;; ===================================================================================================
|
121 | 128 | ;; Short-cutting axis folds
|
122 | 129 |
|
123 | 130 | (: array-axis-and (All (A) ((Array A) Integer -> (Array (U A Boolean)))))
|
124 | 131 | (define (array-axis-and arr k)
|
125 | 132 | (let ([k (check-array-axis 'array-axis-and arr k)])
|
126 |
| - (unsafe-array-axis-reduce |
127 |
| - arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
128 |
| - (let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t]) |
129 |
| - (cond [(jk . fx< . dk) (define v (and acc (proc jk))) |
130 |
| - (if v (loop (fx+ jk 1) v) v)] |
131 |
| - [else acc])))))) |
| 133 | + (array-default-strict |
| 134 | + (unsafe-array-axis-reduce |
| 135 | + arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
| 136 | + (let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t]) |
| 137 | + (cond [(jk . fx< . dk) (define v (and acc (proc jk))) |
| 138 | + (if v (loop (fx+ jk 1) v) v)] |
| 139 | + [else acc]))))))) |
132 | 140 |
|
133 | 141 | (: array-axis-or (All (A) ((Array A) Integer -> (Array (U A #f)))))
|
134 | 142 | (define (array-axis-or arr k)
|
135 | 143 | (let ([k (check-array-axis 'array-axis-or arr k)])
|
136 |
| - (unsafe-array-axis-reduce |
137 |
| - arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
138 |
| - (let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f]) |
139 |
| - (cond [(jk . fx< . dk) (define v (or acc (proc jk))) |
140 |
| - (if v v (loop (fx+ jk 1) v))] |
141 |
| - [else acc])))))) |
| 144 | + (array-default-strict |
| 145 | + (unsafe-array-axis-reduce |
| 146 | + arr k (λ: ([dk : Index] [proc : (Index -> A)]) |
| 147 | + (let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f]) |
| 148 | + (cond [(jk . fx< . dk) (define v (or acc (proc jk))) |
| 149 | + (if v v (loop (fx+ jk 1) v))] |
| 150 | + [else acc]))))))) |
142 | 151 |
|
143 | 152 | (: array-all-and (All (A B) ((Array A) -> (U A Boolean))))
|
144 | 153 | (define (array-all-and arr)
|
145 |
| - (array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #())) |
| 154 | + ;; See `array-all-fold' for why non-strictness is okay |
| 155 | + (parameterize ([array-strictness #f]) |
| 156 | + (array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #()))) |
146 | 157 |
|
147 | 158 | (: array-all-or (All (A B) ((Array A) -> (U A #f))))
|
148 | 159 | (define (array-all-or arr)
|
149 |
| - (array-ref ((inst array-fold (U A #f)) arr array-axis-or) #())) |
| 160 | + ;; See `array-all-fold' for why non-strictness is okay |
| 161 | + (parameterize ([array-strictness #f]) |
| 162 | + (array-ref ((inst array-fold (U A #f)) arr array-axis-or) #()))) |
150 | 163 |
|
151 | 164 | ;; ===================================================================================================
|
152 | 165 | ;; Other folds
|
|
156 | 169 | (define (array->list-array arr [k 0])
|
157 | 170 | (define dims (array-dims arr))
|
158 | 171 | (cond [(and (k . >= . 0) (k . < . dims))
|
159 |
| - (unsafe-array-axis-reduce arr k (inst build-list A))] |
| 172 | + (array-default-strict |
| 173 | + (unsafe-array-axis-reduce arr k (inst build-list A)))] |
160 | 174 | [else
|
161 | 175 | (raise-argument-error 'array->list-array (format "Index < ~a" dims) 1 arr k)]))
|
0 commit comments