Skip to content

Commit 54f029b

Browse files
Neil Torontormculpepper
Neil Toronto
authored andcommitted
Made arrays strict by default; please merge to release
* Added parameter `array-strictness', default #t * Added `array-default-strict!' and `array-default-strict', which act like the functions without "default" in the name when `array-strictness' is #t; otherwise they do nothing * Lots of small changes to existing array functions, mostly to ensure computations are done using nonstrict arrays, but return values are strict when `array-strictness' is #t * Added strictness tests * Added tests to ensure untyped code can use `math/array' * Rewrote `array-map' exported to untyped code using untyped Racket * Rearranged a lot of `math/array' documentation (cherry picked from commit 986e695)
1 parent 5c19a88 commit 54f029b

21 files changed

+1556
-808
lines changed

collects/math/array.rkt

-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
"private/array/array-transform.rkt"
99
"private/array/array-convert.rkt"
1010
"private/array/array-fold.rkt"
11-
"private/array/array-special-folds.rkt"
1211
"private/array/array-unfold.rkt"
1312
"private/array/array-print.rkt"
1413
"private/array/array-fft.rkt"
@@ -36,7 +35,6 @@
3635
"private/array/array-transform.rkt"
3736
"private/array/array-convert.rkt"
3837
"private/array/array-fold.rkt"
39-
"private/array/array-special-folds.rkt"
4038
"private/array/array-unfold.rkt"
4139
"private/array/array-print.rkt"
4240
"private/array/array-syntax.rkt"

collects/math/private/array/array-broadcast.rkt

+5-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@
4040

4141
(: array-broadcast (All (A) ((Array A) Indexes -> (Array A))))
4242
(define (array-broadcast arr ds)
43-
(if (equal? ds (array-shape arr)) arr (shift-stretch-axes arr ds)))
43+
(cond [(equal? ds (array-shape arr)) arr]
44+
[else (define new-arr (shift-stretch-axes arr ds))
45+
(if (or (array-strict? arr) ((array-size new-arr) . fx<= . (array-size arr)))
46+
new-arr
47+
(array-default-strict new-arr))]))
4448

4549
(: shape-insert-axes (Indexes Fixnum -> Indexes))
4650
(define (shape-insert-axes ds n)

collects/math/private/array/array-fft.rkt

+3-2
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@
5151
[(= k (- dims 1))
5252
(fcarray-last-axis-fft (array->fcarray arr))]
5353
[else
54-
(array-axis-swap (fcarray-last-axis-fft (array->fcarray (array-axis-swap arr k (- dims 1))))
55-
k (- dims 1))]))
54+
(parameterize ([array-strictness #f])
55+
(array-axis-swap (fcarray-last-axis-fft (array->fcarray (array-axis-swap arr k (- dims 1))))
56+
k (- dims 1)))]))
5657

5758
(: fcarray-fft (FCArray -> FCArray))
5859
(define (fcarray-fft arr)

collects/math/private/array/array-fold.rkt

+22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#lang racket/base
22

33
(require (for-syntax racket/base)
4+
(only-in typed/racket/base assert index?)
5+
"array-struct.rkt"
6+
"array-pointwise.rkt"
47
"typed-array-fold.rkt")
58

69
;; ===================================================================================================
@@ -28,6 +31,22 @@
2831
(define-all-fold array-all-min min)
2932
(define-all-fold array-all-max max)
3033

34+
(define-syntax-rule (array-count f arr ...)
35+
(assert
36+
(parameterize ([array-strictness #f])
37+
(array-all-sum (inline-array-map (λ (b) (if b 1 0))
38+
(array-map f arr ...))
39+
0))
40+
index?))
41+
42+
(define-syntax-rule (array-andmap pred? arr ...)
43+
(parameterize ([array-strictness #f])
44+
(array-all-and (array-map pred? arr ...))))
45+
46+
(define-syntax-rule (array-ormap pred? arr ...)
47+
(parameterize ([array-strictness #f])
48+
(array-all-or (array-map pred? arr ...))))
49+
3150
(provide array-axis-fold
3251
array-axis-sum
3352
array-axis-prod
@@ -44,6 +63,9 @@
4463
array-all-max
4564
array-all-and
4665
array-all-or
66+
array-count
67+
array-andmap
68+
array-ormap
4769
array-axis-reduce
4870
unsafe-array-axis-reduce
4971
array->list-array)

collects/math/private/array/array-special-folds.rkt

-27
This file was deleted.

collects/math/private/array/array-struct.rkt

+12-4
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
"array-syntax.rkt"
77
(except-in "typed-array-struct.rkt"
88
build-array
9-
build-strict-array
9+
build-simple-array
1010
list->array))
1111

1212
(require/untyped-contract
1313
(begin (require "typed-array-struct.rkt"))
1414
"typed-array-struct.rkt"
1515
[build-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
16-
[build-strict-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
16+
[build-simple-array (All (A) ((Vectorof Integer) ((Vectorof Index) -> A) -> (Array A)))]
1717
[list->array (All (A) (case-> ((Listof A) -> (Array A))
1818
((Vectorof Integer) (Listof A) -> (Array A))))])
1919

@@ -29,15 +29,18 @@
2929
array-shape
3030
array-dims
3131
array-size
32+
array-strictness
3233
array-strict
3334
array-strict!
35+
array-default-strict
36+
array-default-strict!
3437
array-strict?
3538
build-array
36-
build-strict-array
39+
build-simple-array
3740
list->array
3841
make-unsafe-array-proc
3942
unsafe-build-array
40-
unsafe-build-strict-array
43+
unsafe-build-simple-array
4144
unsafe-list->array
4245
unsafe-array-proc
4346
array-lazy
@@ -65,3 +68,8 @@
6568
(let ([arr arr-expr])
6669
(array-strict! arr)
6770
arr))
71+
72+
(define-syntax-rule (array-default-strict arr-expr)
73+
(let ([arr arr-expr])
74+
(array-default-strict! arr)
75+
arr))

collects/math/private/array/array-unfold.rkt

+4-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
(define (array-axis-expand arr k dk f)
3434
(let ([k (check-array-axis 'array-axis-expand arr k)])
3535
(cond [(not (index? dk)) (raise-argument-error 'array-axis-expand "Index" 2 arr k dk f)]
36-
[else (unsafe-array-axis-expand arr k dk f)])))
36+
[else (array-default-strict
37+
(unsafe-array-axis-expand arr k dk f))])))
3738

3839
;; ===================================================================================================
3940
;; Specific unfolds/expansions
@@ -46,6 +47,7 @@
4647
(let ([arr (array-strict (array-map (inst list->vector A) arr))])
4748
;(define dks (remove-duplicates (array->list (array-map vector-length arr))))
4849
(define dk (array-all-min (array-map vector-length arr)))
49-
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A)))]
50+
(array-default-strict
51+
(unsafe-array-axis-expand arr k dk (inst unsafe-vector-ref A))))]
5052
[else
5153
(raise-argument-error 'list-array->array (format "Index <= ~a" dims) 1 arr k)]))

collects/math/private/array/typed-array-constructors.rkt

+7-7
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,29 @@
1010
(define (make-array ds v)
1111
(let ([ds (check-array-shape
1212
ds (λ () (raise-argument-error 'make-array "(Vectorof Index)" 0 ds v)))])
13-
(unsafe-build-strict-array ds (λ (js) v))))
13+
(unsafe-build-simple-array ds (λ (js) v))))
1414

1515
(: axis-index-array (In-Indexes Integer -> (Array Index)))
1616
(define (axis-index-array ds k)
1717
(let* ([ds (check-array-shape
1818
ds (λ () (raise-argument-error 'axis-index-array "(Vectorof Index)" 0 ds k)))]
1919
[dims (vector-length ds)])
2020
(cond [(and (0 . <= . k) (k . < . dims))
21-
(unsafe-build-strict-array ds (λ: ([js : Indexes]) (unsafe-vector-ref js k)))]
21+
(unsafe-build-simple-array ds (λ: ([js : Indexes]) (unsafe-vector-ref js k)))]
2222
[else (raise-argument-error 'axis-index-array (format "Index < ~a" dims) 1 ds k)])))
2323

2424
(: index-array (In-Indexes -> (Array Index)))
2525
(define (index-array ds)
2626
(let ([ds (check-array-shape
2727
ds (λ () (raise-argument-error 'index-array "(Vectorof Index)" ds)))])
28-
(unsafe-build-strict-array ds (λ: ([js : Indexes])
28+
(unsafe-build-simple-array ds (λ: ([js : Indexes])
2929
(assert (unsafe-array-index->value-index ds js) index?)))))
3030

3131
(: indexes-array (In-Indexes -> (Array Indexes)))
3232
(define (indexes-array ds)
3333
(let ([ds (check-array-shape
3434
ds (λ () (raise-argument-error 'indexes-array "(Vectorof Index)" ds)))])
35-
(unsafe-build-strict-array ds (λ: ([js : Indexes]) (vector-copy-all js)))))
35+
(unsafe-build-simple-array ds (λ: ([js : Indexes]) (vector-copy-all js)))))
3636

3737
(: diagonal-array (All (A) (Integer Integer A A -> (Array A))))
3838
(define (diagonal-array dims size on-value off-value)
@@ -42,15 +42,15 @@
4242
(define: ds : Indexes (make-vector dims size))
4343
;; specialize for various cases
4444
(cond [(or (dims . <= . 1) (size . <= . 1))
45-
(unsafe-build-strict-array ds (λ: ([js : Indexes]) on-value))]
45+
(unsafe-build-simple-array ds (λ: ([js : Indexes]) on-value))]
4646
[(= dims 2)
47-
(unsafe-build-strict-array
47+
(unsafe-build-simple-array
4848
ds (λ: ([js : Indexes])
4949
(define j0 (unsafe-vector-ref js 0))
5050
(define j1 (unsafe-vector-ref js 1))
5151
(if (= j0 j1) on-value off-value)))]
5252
[else
53-
(unsafe-build-strict-array
53+
(unsafe-build-simple-array
5454
ds (λ: ([js : Indexes])
5555
(define j0 (unsafe-vector-ref js 0))
5656
(let: loop : A ([i : Nonnegative-Fixnum 1])

collects/math/private/array/typed-array-fold.rkt

+53-39
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,16 @@
3838
(: array-axis-reduce (All (A B) ((Array A) Integer (Index (Integer -> A) -> B) -> (Array B))))
3939
(define (array-axis-reduce arr k f)
4040
(let ([k (check-array-axis 'array-axis-reduce arr k)])
41-
(unsafe-array-axis-reduce
42-
arr k
43-
(λ: ([dk : Index] [proc : (Index -> A)])
44-
(: safe-proc (Integer -> A))
45-
(define (safe-proc jk)
46-
(cond [(or (jk . < . 0) (jk . >= . dk))
47-
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
48-
[else (proc jk)]))
49-
(f dk safe-proc)))))
41+
(array-default-strict
42+
(unsafe-array-axis-reduce
43+
arr k
44+
(λ: ([dk : Index] [proc : (Index -> A)])
45+
(: safe-proc (Integer -> A))
46+
(define (safe-proc jk)
47+
(cond [(or (jk . < . 0) (jk . >= . dk))
48+
(raise-argument-error 'array-axis-reduce (format "Index < ~a" dk) jk)]
49+
[else (proc jk)]))
50+
(f dk safe-proc))))))
5051

5152
(: array-axis-fold/init (All (A B) ((Array A) Integer (A B -> B) B -> (Array B))))
5253
(define (array-axis-fold/init arr k f init)
@@ -72,8 +73,8 @@
7273
((Array A) Integer (A B -> B) B -> (Array B)))))
7374
(define array-axis-fold
7475
(case-lambda
75-
[(arr k f) (array-axis-fold/no-init arr k f)]
76-
[(arr k f init) (array-axis-fold/init arr k f init)]))
76+
[(arr k f) (array-default-strict (array-axis-fold/no-init arr k f))]
77+
[(arr k f init) (array-default-strict (array-axis-fold/init arr k f init))]))
7778

7879
;; ===================================================================================================
7980
;; Whole-array folds
@@ -93,13 +94,18 @@
9394
(define array-all-fold
9495
(case-lambda
9596
[(arr f)
96-
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
97-
(array-axis-fold arr k f)))
98-
#())]
97+
;; Though `f' is folded over multiple axes, each element of `arr' is referred to only once, so
98+
;; turning strictness off can't hurt performance
99+
(parameterize ([array-strictness #f])
100+
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
101+
(array-axis-fold arr k f)))
102+
#()))]
99103
[(arr f init)
100-
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
101-
(array-axis-fold arr k f init)))
102-
#())]))
104+
;; See above for why non-strictness is okay
105+
(parameterize ([array-strictness #f])
106+
(array-ref (array-fold arr (λ: ([arr : (Array A)] [k : Index])
107+
(array-axis-fold arr k f init)))
108+
#()))]))
103109

104110
) ; begin-encourage-inline
105111

@@ -109,44 +115,51 @@
109115
(: array-axis-count (All (A) ((Array A) Integer (A -> Any) -> (Array Index))))
110116
(define (array-axis-count arr k pred?)
111117
(let ([k (check-array-axis 'array-axis-count arr k)])
112-
(unsafe-array-axis-reduce
113-
arr k (λ: ([dk : Index] [proc : (Index -> A)])
114-
(let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0])
115-
(if (jk . fx< . dk)
116-
(cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))]
117-
[else (loop (fx+ jk 1) acc)])
118-
(assert acc index?)))))))
118+
(array-default-strict
119+
(unsafe-array-axis-reduce
120+
arr k (λ: ([dk : Index] [proc : (Index -> A)])
121+
(let: loop : Index ([jk : Nonnegative-Fixnum 0] [acc : Nonnegative-Fixnum 0])
122+
(if (jk . fx< . dk)
123+
(cond [(pred? (proc jk)) (loop (fx+ jk 1) (unsafe-fx+ acc 1))]
124+
[else (loop (fx+ jk 1) acc)])
125+
(assert acc index?))))))))
119126

120127
;; ===================================================================================================
121128
;; Short-cutting axis folds
122129

123130
(: array-axis-and (All (A) ((Array A) Integer -> (Array (U A Boolean)))))
124131
(define (array-axis-and arr k)
125132
(let ([k (check-array-axis 'array-axis-and arr k)])
126-
(unsafe-array-axis-reduce
127-
arr k (λ: ([dk : Index] [proc : (Index -> A)])
128-
(let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t])
129-
(cond [(jk . fx< . dk) (define v (and acc (proc jk)))
130-
(if v (loop (fx+ jk 1) v) v)]
131-
[else acc]))))))
133+
(array-default-strict
134+
(unsafe-array-axis-reduce
135+
arr k (λ: ([dk : Index] [proc : (Index -> A)])
136+
(let: loop : (U A Boolean) ([jk : Nonnegative-Fixnum 0] [acc : (U A Boolean) #t])
137+
(cond [(jk . fx< . dk) (define v (and acc (proc jk)))
138+
(if v (loop (fx+ jk 1) v) v)]
139+
[else acc])))))))
132140

133141
(: array-axis-or (All (A) ((Array A) Integer -> (Array (U A #f)))))
134142
(define (array-axis-or arr k)
135143
(let ([k (check-array-axis 'array-axis-or arr k)])
136-
(unsafe-array-axis-reduce
137-
arr k (λ: ([dk : Index] [proc : (Index -> A)])
138-
(let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f])
139-
(cond [(jk . fx< . dk) (define v (or acc (proc jk)))
140-
(if v v (loop (fx+ jk 1) v))]
141-
[else acc]))))))
144+
(array-default-strict
145+
(unsafe-array-axis-reduce
146+
arr k (λ: ([dk : Index] [proc : (Index -> A)])
147+
(let: loop : (U A #f) ([jk : Nonnegative-Fixnum 0] [acc : (U A #f) #f])
148+
(cond [(jk . fx< . dk) (define v (or acc (proc jk)))
149+
(if v v (loop (fx+ jk 1) v))]
150+
[else acc])))))))
142151

143152
(: array-all-and (All (A B) ((Array A) -> (U A Boolean))))
144153
(define (array-all-and arr)
145-
(array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #()))
154+
;; See `array-all-fold' for why non-strictness is okay
155+
(parameterize ([array-strictness #f])
156+
(array-ref ((inst array-fold (U A Boolean)) arr array-axis-and) #())))
146157

147158
(: array-all-or (All (A B) ((Array A) -> (U A #f))))
148159
(define (array-all-or arr)
149-
(array-ref ((inst array-fold (U A #f)) arr array-axis-or) #()))
160+
;; See `array-all-fold' for why non-strictness is okay
161+
(parameterize ([array-strictness #f])
162+
(array-ref ((inst array-fold (U A #f)) arr array-axis-or) #())))
150163

151164
;; ===================================================================================================
152165
;; Other folds
@@ -156,6 +169,7 @@
156169
(define (array->list-array arr [k 0])
157170
(define dims (array-dims arr))
158171
(cond [(and (k . >= . 0) (k . < . dims))
159-
(unsafe-array-axis-reduce arr k (inst build-list A))]
172+
(array-default-strict
173+
(unsafe-array-axis-reduce arr k (inst build-list A)))]
160174
[else
161175
(raise-argument-error 'array->list-array (format "Index < ~a" dims) 1 arr k)]))

0 commit comments

Comments
 (0)