Skip to content
This repository was archived by the owner on Feb 18, 2020. It is now read-only.

Commit c1212d6

Browse files
CArray::svd casting
1 parent 8ab5977 commit c1212d6

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

kernel/linalg.c

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -614,17 +614,28 @@ CArray_Vdot(CArray * a, CArray * b, MemoryPointer * out)
614614
CArray **
615615
CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
616616
{
617-
int m, n;
617+
int m, n, casted = 0;
618618
int lda, ldu, ldvt, info, lwork;
619619
int * iwork;
620620
double * s, * u, * vt;
621-
CArray * u_ca, * s_ca, *vh_ca, ** rtn;
621+
CArray * u_ca, * s_ca, *vh_ca, ** rtn, * target;
622622

623623
if (CArray_NDIM(a) != 2) {
624624
throw_valueerror_exception("Expected 2D array");
625625
return NULL;
626626
}
627627

628+
if (CArray_DESCR(a)->type_num != TYPE_DOUBLE_INT) {
629+
CArrayDescriptor *descr = CArray_DescrFromType(TYPE_DOUBLE_INT);
630+
target = CArray_NewLikeArray(a, CARRAY_CORDER, descr, 0);
631+
if(CArray_CastTo(target, a) < 0) {
632+
return NULL;
633+
}
634+
casted = 1;
635+
} else {
636+
target = a;
637+
}
638+
628639
m = CArray_DIMS(a)[0];
629640
n = CArray_DIMS(a)[1];
630641

@@ -637,33 +648,33 @@ CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
637648
u = emalloc(sizeof(double) * (ldu * m));
638649
vt = emalloc(sizeof(double) * (ldvt * n));
639650

640-
info = LAPACKE_dgesdd(LAPACK_ROW_MAJOR, 'S', m, n, DDATA(a), lda, s, u, ldu, vt, ldvt);
651+
info = LAPACKE_dgesdd(LAPACK_ROW_MAJOR, 'S', m, n, DDATA(target), lda, s, u, ldu, vt, ldvt);
641652

642653
if( info > 0 ) {
643654
throw_valueerror_exception( "The algorithm computing SVD failed to converge." );
644655
return NULL;
645656
}
646657

647658
u_ca = emalloc(sizeof(CArray));
648-
u_ca = CArray_NewFromDescr_int(u_ca, CArray_DESCR(a), CArray_NDIM(a), CArray_DIMS(a), CArray_STRIDES(a),
659+
u_ca = CArray_NewFromDescr_int(u_ca, CArray_DESCR(target), CArray_NDIM(target), CArray_DIMS(target), CArray_STRIDES(target),
649660
NULL, 0, NULL, 1, 0);
650661
efree(u_ca->data);
651662
u_ca->data = (char *)u;
652-
CArrayDescriptor_INCREF(CArray_DESCR(a));
663+
CArrayDescriptor_INCREF(CArray_DESCR(target));
653664

654665
s_ca = emalloc(sizeof(CArray));
655-
s_ca = CArray_NewFromDescr_int(s_ca, CArray_DESCR(a), 1, &n, NULL,
666+
s_ca = CArray_NewFromDescr_int(s_ca, CArray_DESCR(target), 1, &n, NULL,
656667
NULL, 0, NULL, 1, 0);
657668
efree(s_ca->data);
658669
s_ca->data = (char *)s;
659-
CArrayDescriptor_INCREF(CArray_DESCR(a));
670+
CArrayDescriptor_INCREF(CArray_DESCR(target));
660671

661672
vh_ca = emalloc(sizeof(CArray));
662-
vh_ca = CArray_NewFromDescr_int(vh_ca, CArray_DESCR(a), CArray_NDIM(a), CArray_DIMS(a), CArray_STRIDES(a),
673+
vh_ca = CArray_NewFromDescr_int(vh_ca, CArray_DESCR(target), CArray_NDIM(target), CArray_DIMS(target), CArray_STRIDES(target),
663674
NULL, 0, NULL, 1, 0);
664675
efree(vh_ca->data);
665676
vh_ca->data = (char *)vt;
666-
CArrayDescriptor_INCREF(CArray_DESCR(a));
677+
CArrayDescriptor_INCREF(CArray_DESCR(target));
667678

668679
if (out != NULL) {
669680
add_to_buffer(&(out[0]), u_ca, sizeof(CArray));
@@ -676,5 +687,10 @@ CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
676687
rtn[1] = s_ca;
677688
rtn[2] = vh_ca;
678689

690+
if (casted) {
691+
CArrayDescriptor_DECREF(CArray_DESCR(target));
692+
CArray_Free(target);
693+
}
694+
679695
return rtn;
680696
}

0 commit comments

Comments
 (0)