Skip to content

Commit 2f8794e

Browse files
authored
Cuda refactoring: (#10195)
Integrates device-based shape buffers for direct shape tries. Updates CMakeLists to ignore helpers/cpu when compiling for CUDA. Removes aurora checks from the nd4j backend.
1 parent 7952478 commit 2f8794e

File tree

23 files changed

+629
-95
lines changed

23 files changed

+629
-95
lines changed

libnd4j/CMakeLists.txt.in

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ cmake_minimum_required(VERSION 2.8.2)
22

33
project(flatbuffers-download NONE)
44

5-
# Force clean everything first
6-
file(REMOVE_RECURSE ${CMAKE_CURRENT_BINARY_DIR}/*)
75

86
include(ExternalProject)
97
ExternalProject_Add(flatbuffers
@@ -16,7 +14,8 @@ ExternalProject_Add(flatbuffers
1614
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
1715
-DFLATBUFFERS_BUILD_FLATC=ON
1816
-DCMAKE_BUILD_TYPE=Release
19-
DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E rm -rf ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src && git clone https://github.com/google/flatbuffers/ ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src
17+
DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E rm -rf ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src COMMAND git clone https://github.com/google/flatbuffers/ ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src
18+
2019
UPDATE_COMMAND ""
2120
INSTALL_COMMAND ""
2221
TEST_COMMAND ""

libnd4j/blas/CMakeLists.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,16 @@ if(SD_CUDA)
356356
file(GLOB_RECURSE CUSTOMOPS_SOURCES ../include/ops/declarable/generic/*.cpp)
357357
file(GLOB_RECURSE CUSTOMOPS_HELPERS_SOURCES ../include/ops/declarable/helpers/cuda/*.cu ../include/ops/declarable/helpers/impl/*.cpp)
358358
file(GLOB_RECURSE OPS_SOURCES ../include/ops/impl/*.cpp ../include/ops/declarable/impl/*.cpp ../include/ops/*.h)
359-
file(GLOB_RECURSE HELPERS_SOURCES ../include/build_info.cpp ../include/ConstMessages.cpp ../include/helpers/*.cpp ../include/helpers/cuda/*.cu ../include/helpers/*.h)
359+
# Gather every helper source for the CUDA build. GLOB_RECURSE descends into all
# sub-directories of include/helpers, so the CPU-only sources under helpers/cpu
# are matched here as well and have to be filtered out afterwards.
file(GLOB_RECURSE HELPERS_SOURCES
        ../include/build_info.cpp
        ../include/ConstMessages.cpp
        ../include/helpers/*.cpp
        ../include/helpers/cuda/*.cu
        ../include/helpers/*.h)
# The exclusion list has to be recursive too: a plain GLOB only matches files
# located directly in helpers/cpu, which would leave sources from nested
# directories (helpers/cpu/<subdir>/*.cpp) in the CUDA build.
file(GLOB_RECURSE CPU_HELPERS_TO_EXCLUDE
        ../include/helpers/cpu/*.cpp)
# remove helpers/cpu
# Guarded because list(REMOVE_ITEM) with no items to remove is rejected by
# older CMake releases.
if(CPU_HELPERS_TO_EXCLUDE)
    list(REMOVE_ITEM HELPERS_SOURCES ${CPU_HELPERS_TO_EXCLUDE})
endif()
360369
file(GLOB_RECURSE INDEXING_SOURCES ../include/indexing/*.cpp ../include/indexing/*.h)
361370
file(GLOB_RECURSE LOOPS_SOURCES ../include/loops/impl/*.cpp ../include/loops/*.h)
362371
file(GLOB_RECURSE LEGACY_SOURCES ../include/legacy/impl/*.cpp ../include/legacy/*.cu ../include/legacy/*.h)

libnd4j/include/helpers/DirectShapeTrie.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
#include <array/ConstantShapeBuffer.h>
55
#include <system/common.h>
6-
6+
#include "ShapeBufferPlatformHelper.h"
77
#include <array>
88
#include <atomic>
99
#include <memory>
@@ -78,6 +78,8 @@ class SD_LIB_EXPORT DirectShapeTrie {
7878
// Make sure mutexes are properly initialized
7979
new (&_mutexes[i]) SHAPE_MUTEX_TYPE(); // Explicit initialization
8080
}
81+
82+
ShapeBufferPlatformHelper::initialize();
8183
}
8284

8385
// Delete copy constructor and assignment
@@ -88,9 +90,6 @@ class SD_LIB_EXPORT DirectShapeTrie {
8890
DirectShapeTrie(DirectShapeTrie&&) = delete;
8991
DirectShapeTrie& operator=(DirectShapeTrie&&) = delete;
9092

91-
// Create a shape buffer from shapeInfo
92-
ConstantShapeBuffer* createBuffer(const LongType* shapeInfo);
93-
9493
// Improved thread-safe getOrCreate
9594
ConstantShapeBuffer* getOrCreate(const LongType* shapeInfo);
9695

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#ifndef LIBND4J_SHAPEBUFFERCREATOR_H
22+
#define LIBND4J_SHAPEBUFFERCREATOR_H
23+
24+
#include "array/ConstantShapeBuffer.h"
25+
26+
namespace sd {
27+
28+
/**
29+
* Interface for creating ConstantShapeBuffer objects.
30+
* This allows for platform-specific implementations (CPU, CUDA, etc).
31+
*/
32+
class ShapeBufferCreator {
33+
public:
34+
virtual ~ShapeBufferCreator() = default;
35+
36+
/**
37+
* Create a ConstantShapeBuffer from the given shape information
38+
*
39+
* @param shapeInfo Pointer to shape information
40+
* @param rank Rank of the shape
41+
* @return A new ConstantShapeBuffer instance
42+
*/
43+
virtual ConstantShapeBuffer* create(const LongType* shapeInfo, int rank) = 0;
44+
};
45+
46+
} // namespace sd
47+
48+
#endif // LIBND4J_SHAPEBUFFERCREATOR_H
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#ifndef LIBND4J_SHAPEBUFFERCREATORHELPER_H
22+
#define LIBND4J_SHAPEBUFFERCREATORHELPER_H
23+
24+
#include <helpers/ShapeBufferCreator.h>
25+
#include <exception>
26+
27+
namespace sd {
28+
29+
/**
30+
* Helper class to manage ShapeBufferCreator instances and provide global access
31+
*/
32+
class ShapeBufferCreatorHelper {
33+
public:
34+
/**
35+
* Get the current ShapeBufferCreator instance
36+
*/
37+
static ShapeBufferCreator& getCurrentCreator();
38+
39+
/**
40+
* Set the current ShapeBufferCreator to use
41+
*/
42+
static void setCurrentCreator(ShapeBufferCreator* creator);
43+
44+
private:
45+
static ShapeBufferCreator* currentCreator_;
46+
};
47+
48+
} // namespace sd
49+
50+
#endif // LIBND4J_SHAPEBUFFERCREATORHELPER_H
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#ifndef LIBND4J_SHAPEBUFFERPLATFORMHELPER_H
22+
#define LIBND4J_SHAPEBUFFERPLATFORMHELPER_H
23+
24+
#include <helpers/ShapeBufferCreatorHelper.h>
25+
26+
namespace sd {
27+
28+
/**
 * Platform-specific initialization helper.
 * Takes care of setting up the correct creators based on the available hardware.
 */
class ShapeBufferPlatformHelper {
 public:
  /**
   * Initialize platform-specific components.
   *
   * This is invoked from more than one place: once through the `initialized`
   * member below, and explicitly by callers (for example the DirectShapeTrie
   * constructor). Implementations therefore have to tolerate repeated calls.
   */
  static void initialize();

  /**
   * Triggers initialize() automatically while static-storage objects are being
   * dynamically initialized.
   *
   * As an inline variable (C++17) it is initialized once per program. Note that
   * the thread-safe initialization guarantee of the language applies to
   * block-scope statics, not to this member, and its initialization is only
   * partially ordered with respect to static objects of other translation
   * units. Code that may run before or during static initialization must not
   * rely on this member and should call initialize() explicitly.
   */
  static inline const bool initialized = (initialize(), true);

 private:
  ShapeBufferPlatformHelper() = delete;  // Prevent instantiation
};
50+
51+
} // namespace sd
52+
53+
#endif // LIBND4J_SHAPEBUFFERPLATFORMHELPER_H
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#include <helpers/cpu/CpuShapeBufferCreator.h>
22+
23+
24+
namespace sd {
25+
26+
/**
 * Builds a host-only ConstantShapeBuffer holding a private copy of shapeInfo.
 *
 * @param shapeInfo shape information to copy; shape::shapeInfoLength(rank)
 *                  elements are read from it
 * @param rank      rank used to size the copy
 * @return newly allocated ConstantShapeBuffer wrapping the copy
 */
ConstantShapeBuffer* CpuShapeBufferCreator::create(const LongType* shapeInfo, int rank) {
  const int shapeInfoLength = shape::shapeInfoLength(rank);

  // Stage the copy in a unique_ptr so it is freed if one of the allocations
  // below throws before the PointerWrapper has taken the buffer over.
  std::unique_ptr<LongType[]> shapeCopy(new LongType[shapeInfoLength]);
  std::memcpy(shapeCopy.get(), shapeInfo, shapeInfoLength * sizeof(LongType));

  auto deallocator = std::make_shared<PrimaryPointerDeallocator>();
  auto hPtr = std::make_shared<PointerWrapper>(shapeCopy.get(), deallocator);
  // From here on the buffer is released through hPtr and its deallocator.
  shapeCopy.release();

  return new ConstantShapeBuffer(hPtr);
}
41+
42+
/**
 * Accessor for the shared CpuShapeBufferCreator.
 * The object is a function-local static, constructed on the first call.
 */
CpuShapeBufferCreator& CpuShapeBufferCreator::getInstance() {
  static CpuShapeBufferCreator creator;
  return creator;
}
46+
47+
} // namespace sd
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#ifndef LIBND4J_CPUSHAPEBUFFERCREATOR_H
22+
#define LIBND4J_CPUSHAPEBUFFERCREATOR_H
23+
24+
#include <helpers/ShapeBufferCreator.h>
25+
#include <helpers/shape.h>
26+
#include <memory>
27+
28+
namespace sd {
29+
30+
/**
31+
* CPU implementation of the ShapeBufferCreator.
32+
*/
33+
class CpuShapeBufferCreator : public ShapeBufferCreator {
34+
public:
35+
/**
36+
* Create a ConstantShapeBuffer for CPU usage
37+
*/
38+
ConstantShapeBuffer* create(const LongType* shapeInfo, int rank) override;
39+
40+
// Singleton pattern implementation
41+
static CpuShapeBufferCreator& getInstance();
42+
43+
private:
44+
CpuShapeBufferCreator() = default;
45+
};
46+
47+
} // namespace sd
48+
49+
#endif // LIBND4J_CPUSHAPEBUFFERCREATOR_H
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* ******************************************************************************
3+
* *
4+
* *
5+
* * This program and the accompanying materials are made available under the
6+
* * terms of the Apache License, Version 2.0 which is available at
7+
* * https://www.apache.org/licenses/LICENSE-2.0.
8+
* *
9+
* * See the NOTICE file distributed with this work for additional
10+
* * information regarding copyright ownership.
11+
* * Unless required by applicable law or agreed to in writing, software
12+
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
* * License for the specific language governing permissions and limitations
15+
* * under the License.
16+
* *
17+
* * SPDX-License-Identifier: Apache-2.0
18+
* *****************************************************************************
19+
*/
20+
21+
#include <helpers/ConstantHelper.h>
22+
#include <helpers/cuda/CudaShapeBufferCreator.h>
23+
24+
#include "array/CudaPointerDeallocator.h"
25+
#include "array/PrimaryPointerDeallocator.h"
26+
27+
namespace sd {
28+
29+
/**
 * Builds a ConstantShapeBuffer with a host copy of shapeInfo and a device
 * replica of that copy.
 *
 * @param shapeInfo shape information to copy; shape::shapeInfoLength(rank)
 *                  elements are read from it
 * @param rank      rank used to size both the host copy and the device replica
 * @return newly allocated ConstantShapeBuffer holding both pointers
 */
ConstantShapeBuffer* CudaShapeBufferCreator::create(const LongType* shapeInfo, int rank) {
  const int shapeInfoLength = shape::shapeInfoLength(rank);
  // One byte count for both copies: sizing the device replica from the rank
  // stored inside the buffer could read past the host allocation whenever that
  // rank disagrees with the `rank` argument.
  const auto shapeInfoBytes = shapeInfoLength * sizeof(LongType);

  // Stage the copy in a unique_ptr so it is freed if one of the allocations
  // below throws before the PointerWrapper has taken the buffer over.
  std::unique_ptr<LongType[]> shapeCopy(new LongType[shapeInfoLength]);
  std::memcpy(shapeCopy.get(), shapeInfo, shapeInfoBytes);

  auto deallocator = std::make_shared<PrimaryPointerDeallocator>();
  auto hPtr = std::make_shared<PointerWrapper>(shapeCopy.get(), deallocator);
  // From here on the buffer is released through hPtr and its deallocator.
  shapeCopy.release();

  // Create device pointer for CUDA
  auto dPtr = std::make_shared<PointerWrapper>(
      ConstantHelper::getInstance().replicatePointer(hPtr->pointer(), shapeInfoBytes),
      std::make_shared<CudaPointerDeallocator>());

  return new ConstantShapeBuffer(hPtr, dPtr);
}
51+
52+
/**
 * Accessor for the shared CudaShapeBufferCreator.
 * The object is a function-local static, constructed on the first call.
 */
CudaShapeBufferCreator& CudaShapeBufferCreator::getInstance() {
  static CudaShapeBufferCreator creator;
  return creator;
}
56+
57+
} // namespace sd

0 commit comments

Comments
 (0)