@@ -614,17 +614,28 @@ CArray_Vdot(CArray * a, CArray * b, MemoryPointer * out)
614
614
CArray * *
615
615
CArray_Svd (CArray * a , int full_matrices , int compute_uv , MemoryPointer * out )
616
616
{
617
- int m , n ;
617
+ int m , n , casted = 0 ;
618
618
int lda , ldu , ldvt , info , lwork ;
619
619
int * iwork ;
620
620
double * s , * u , * vt ;
621
- CArray * u_ca , * s_ca , * vh_ca , * * rtn ;
621
+ CArray * u_ca , * s_ca , * vh_ca , * * rtn , * target ;
622
622
623
623
if (CArray_NDIM (a ) != 2 ) {
624
624
throw_valueerror_exception ("Expected 2D array" );
625
625
return NULL ;
626
626
}
627
627
628
+ if (CArray_DESCR (a )-> type_num != TYPE_DOUBLE_INT ) {
629
+ CArrayDescriptor * descr = CArray_DescrFromType (TYPE_DOUBLE_INT );
630
+ target = CArray_NewLikeArray (a , CARRAY_CORDER , descr , 0 );
631
+ if (CArray_CastTo (target , a ) < 0 ) {
632
+ return NULL ;
633
+ }
634
+ casted = 1 ;
635
+ } else {
636
+ target = a ;
637
+ }
638
+
628
639
m = CArray_DIMS (a )[0 ];
629
640
n = CArray_DIMS (a )[1 ];
630
641
@@ -637,33 +648,33 @@ CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
637
648
u = emalloc (sizeof (double ) * (ldu * m ));
638
649
vt = emalloc (sizeof (double ) * (ldvt * n ));
639
650
640
- info = LAPACKE_dgesdd (LAPACK_ROW_MAJOR , 'S' , m , n , DDATA (a ), lda , s , u , ldu , vt , ldvt );
651
+ info = LAPACKE_dgesdd (LAPACK_ROW_MAJOR , 'S' , m , n , DDATA (target ), lda , s , u , ldu , vt , ldvt );
641
652
642
653
if ( info > 0 ) {
643
654
throw_valueerror_exception ( "The algorithm computing SVD failed to converge." );
644
655
return NULL ;
645
656
}
646
657
647
658
u_ca = emalloc (sizeof (CArray ));
648
- u_ca = CArray_NewFromDescr_int (u_ca , CArray_DESCR (a ), CArray_NDIM (a ), CArray_DIMS (a ), CArray_STRIDES (a ),
659
+ u_ca = CArray_NewFromDescr_int (u_ca , CArray_DESCR (target ), CArray_NDIM (target ), CArray_DIMS (target ), CArray_STRIDES (target ),
649
660
NULL , 0 , NULL , 1 , 0 );
650
661
efree (u_ca -> data );
651
662
u_ca -> data = (char * )u ;
652
- CArrayDescriptor_INCREF (CArray_DESCR (a ));
663
+ CArrayDescriptor_INCREF (CArray_DESCR (target ));
653
664
654
665
s_ca = emalloc (sizeof (CArray ));
655
- s_ca = CArray_NewFromDescr_int (s_ca , CArray_DESCR (a ), 1 , & n , NULL ,
666
+ s_ca = CArray_NewFromDescr_int (s_ca , CArray_DESCR (target ), 1 , & n , NULL ,
656
667
NULL , 0 , NULL , 1 , 0 );
657
668
efree (s_ca -> data );
658
669
s_ca -> data = (char * )s ;
659
- CArrayDescriptor_INCREF (CArray_DESCR (a ));
670
+ CArrayDescriptor_INCREF (CArray_DESCR (target ));
660
671
661
672
vh_ca = emalloc (sizeof (CArray ));
662
- vh_ca = CArray_NewFromDescr_int (vh_ca , CArray_DESCR (a ), CArray_NDIM (a ), CArray_DIMS (a ), CArray_STRIDES (a ),
673
+ vh_ca = CArray_NewFromDescr_int (vh_ca , CArray_DESCR (target ), CArray_NDIM (target ), CArray_DIMS (target ), CArray_STRIDES (target ),
663
674
NULL , 0 , NULL , 1 , 0 );
664
675
efree (vh_ca -> data );
665
676
vh_ca -> data = (char * )vt ;
666
- CArrayDescriptor_INCREF (CArray_DESCR (a ));
677
+ CArrayDescriptor_INCREF (CArray_DESCR (target ));
667
678
668
679
if (out != NULL ) {
669
680
add_to_buffer (& (out [0 ]), u_ca , sizeof (CArray ));
@@ -676,5 +687,10 @@ CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
676
687
rtn [1 ] = s_ca ;
677
688
rtn [2 ] = vh_ca ;
678
689
690
+ if (casted ) {
691
+ CArrayDescriptor_DECREF (CArray_DESCR (target ));
692
+ CArray_Free (target );
693
+ }
694
+
679
695
return rtn ;
680
696
}
0 commit comments