Skip to content

Commit 4282bb7

Browse files
authored
Fix cuda op executioner exec (#10199)
* Fix op executioner reshape; fix CPU build with the new shape buffer creator; fix Kotlin compiler in IntelliJ * Remove the toString() call
1 parent 705f913 commit 4282bb7

File tree

11 files changed

+46
-245
lines changed

11 files changed

+46
-245
lines changed

libnd4j/include/array/PointerDeallocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace sd {
3232
class SD_LIB_EXPORT PointerDeallocator {
3333
public:
3434
PointerDeallocator() = default;
35-
~PointerDeallocator() = default;
35+
virtual ~PointerDeallocator() = default;
3636

3737
virtual void release(void *ptr);
3838
};

libnd4j/include/helpers/cpu/CpuShapeBufferCreator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ ConstantShapeBuffer* CpuShapeBufferCreator::create(const LongType* shapeInfo, in
2828
LongType* shapeCopy = new LongType[shapeInfoLength];
2929
std::memcpy(shapeCopy, shapeInfo, shapeInfoLength * sizeof(LongType));
3030

31-
auto deallocator = std::shared_ptr<PrimaryPointerDeallocator>(
32-
new PrimaryPointerDeallocator(),
33-
[] (PrimaryPointerDeallocator* ptr) { delete ptr; }
31+
auto deallocator = std::shared_ptr<PointerDeallocator>(
32+
new PointerDeallocator(),
33+
[] (PointerDeallocator* ptr) { delete ptr; }
3434
);
3535

3636
auto hPtr = std::make_shared<PointerWrapper>(shapeCopy, deallocator);

libnd4j/include/legacy/cpu/NativeOps.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,19 @@ using namespace sd;
7070

7171

7272
//these are mainly for cuda
73-
sd::Pointer lcScalarPointer(OpaqueLaunchContext *lc) { return nullptr; }
73+
sd::Pointer lcScalarPointer(OpaqueLaunchContext lc) { return nullptr; }
7474

75-
sd::Pointer lcReductionPointer(OpaqueLaunchContext *lc) { return nullptr; }
75+
sd::Pointer lcReductionPointer(OpaqueLaunchContext lc) { return nullptr; }
7676

77-
sd::Pointer lcAllocationPointer(OpaqueLaunchContext *lc) { return nullptr; }
77+
sd::Pointer lcAllocationPointer(OpaqueLaunchContext lc) { return nullptr; }
7878

79-
sd::Pointer lcExecutionStream(OpaqueLaunchContext *lc) { return nullptr; }
79+
sd::Pointer lcExecutionStream(OpaqueLaunchContext lc) { return nullptr; }
8080

81-
sd::Pointer lcCopyStream(OpaqueLaunchContext *lc) { return nullptr; }
81+
sd::Pointer lcCopyStream(OpaqueLaunchContext lc) { return nullptr; }
8282

83-
sd::Pointer lcBlasHandle(OpaqueLaunchContext *lc) { return nullptr; }
83+
sd::Pointer lcBlasHandle(OpaqueLaunchContext lc) { return nullptr; }
8484

85-
sd::Pointer lcSolverHandle(OpaqueLaunchContext *lc) { return nullptr; }
85+
sd::Pointer lcSolverHandle(OpaqueLaunchContext lc) { return nullptr; }
8686

8787

8888
void execBroadcastBool(Pointer *extraPointers, int opNum, NDArray *x, NDArray *y,

libnd4j/include/loops/cuda/specials/tileKernel.cu

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,22 @@
175175

176176
if (outOrder == 'c') {
177177
for (LongType i = tid; i < resultLength; i += totalThreads) {
178-
// We do direct linear offset for output as i
179-
// The offset in the input is determined by the out-coords
180-
// mapped to the input stride
181-
sd::LongType coords[SD_MAX_RANK];
178+
sd::LongType outCoords[SD_MAX_RANK];
179+
sd::LongType inCoords[SD_MAX_RANK];
182180
sd::LongType inOffset;
183181

184-
INDEX2COORDS(i, outRank, outShapePtr, coords);
185-
COORDS2INDEX(outRank, inStridePtr, coords, inOffset);
182+
// Get output coordinates
183+
INDEX2COORDS(i, outRank, outShapePtr, outCoords);
184+
185+
// Map to input coordinates (using modulo for tiling)
186+
for (int d = 0; d < inRank; d++) {
187+
inCoords[d] = outCoords[d] % inShapePtr[d];
188+
}
186189

187-
outData[i] = static_cast<X>(inData[inOffset]);
190+
// Get input offset from input coordinates
191+
COORDS2INDEX(inRank, inStridePtr, inCoords, inOffset);
192+
193+
outData[i] = inData[inOffset];
188194
}
189195
}
190196
else {

nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -868,26 +868,6 @@ public void commit() {
868868
// no-op
869869
}
870870

871-
872-
873-
874-
private long _length(long[] shape) {
875-
// scalar case
876-
if (shape.length == 0)
877-
return 1;
878-
else if (shape.length == 1)
879-
return shape[0];
880-
else {
881-
long length = 1;
882-
for (int e = 0; e < shape.length; e++)
883-
length *= shape[e];
884-
885-
return length;
886-
}
887-
}
888-
889-
890-
891871
@Override
892872
public Map<String, CustomOpDescriptor> getCustomOperations() {
893873
throw new UnsupportedOperationException();

nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java

Lines changed: 15 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,75 +1544,44 @@ public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
15441544
return customOps;
15451545
}
15461546

1547+
15471548
/**
15481549
* This method executes given CustomOp
15491550
*
15501551
* PLEASE NOTE: You're responsible for input/output validation
1551-
* PLEASE NOTE: right now this operations are executing on CPU
1552-
* @param op
1552+
* @param op Operation to execute
15531553
*/
15541554
@Override
1555-
public INDArray[] exec(CustomOp op) {
1556-
1557-
Nd4j.getExecutioner().commit();
1558-
1559-
boolean shapeOverride = false;
1560-
if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
1561-
try {
1562-
val list = this.calculateOutputShape(op);
1563-
if (list.isEmpty())
1564-
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
1565-
1566-
for (val shape: list)
1567-
op.addOutputArgument(Nd4j.create(shape, false));
1568-
1569-
shapeOverride = true;
1570-
} catch (Exception e) {
1571-
throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
1572-
}
1573-
}
1574-
1575-
1576-
1555+
public INDArray[] exec(@NonNull CustomOp op) {
15771556
val name = op.opName();
1578-
try (val context = (CudaOpContext) buildContext()) {
1579-
// optionally skip shape validation on op execution
1580-
if (shapeOverride)
1581-
context.shapeFunctionOverride(true);
1582-
1583-
context.markInplace(op.isInplaceCall());
1584-
1585-
// transferring rng state
1586-
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
1587-
1588-
//transferring input/output arrays
1589-
context.setInputArrays(op.inputArguments());
1590-
context.setOutputArrays(op.outputArguments());
1591-
1592-
// transferring static args
1593-
context.setBArguments(op.bArgs());
1594-
context.setIArguments(op.iArgs());
1595-
context.setTArguments(op.tArgs());
1596-
context.setDArguments(op.dArgs());
1557+
try (val context = buildContext()) {
1558+
op.setupOpContextFromCustomOp(context);
1559+
boolean shapeOverride = op.initializeOutputs(context);
1560+
long start = profilingConfigurableHookIn(op,context);
1561+
initOpContext(op, shapeOverride, context);
15971562

15981563
val result = exec(op, context);
15991564
val states = context.getRngStates();
16001565

16011566

16021567
// pulling states back
16031568
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
1569+
profilingConfigurableHookOut(op,context,start);
16041570

16051571
return result;
16061572
} catch (ND4JOpProfilerException e) {
1573+
16071574
throw e;
16081575
} catch (Exception e) {
1609-
StringBuilder message = new StringBuilder();
1610-
message.append("Op [" + name + "] execution failed with error " + "Cuda last error message: " + cudaGetErrorName(org.bytedeco.cuda.global.cublas.cublasGetError()).getString());
1611-
throw new RuntimeException(message.toString(), e);
1576+
throw new RuntimeException("Op [" + name + "] execution failed", e);
16121577
}
1578+
1579+
16131580
}
16141581

16151582

1583+
1584+
16161585
@Override
16171586
public ExecutionerType type() {
16181587
return ExecutionerType.CUDA;

platform-tests/src/test/java/org/eclipse/deeplearning4j/nd4j/linalg/BroadcastingOpsSmokeTests.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,6 @@ public void testInPlaceViewBroadcastOperations() {
10191019
// Get views
10201020
INDArray rowView = matrix.getRow(1);
10211021
INDArray colView = matrix.getColumn(1);
1022-
10231022
// Create vectors for broadcasting
10241023
INDArray rowVector = Nd4j.create(new double[] {10, 20, 30});
10251024
INDArray colVector = Nd4j.create(new double[] {10, 20, 30}).reshape(3, 1);
Original file line numberDiff line numberDiff line change
@@ -1,69 +0,0 @@
1-
/*
2-
* ******************************************************************************
3-
* *
4-
* *
5-
* * This program and the accompanying materials are made available under the
6-
* * terms of the Apache License, Version 2.0 which is available at
7-
* * https://www.apache.org/licenses/LICENSE-2.0.
8-
* *
9-
* * See the NOTICE file distributed with this work for additional
10-
* * information regarding copyright ownership.
11-
* * Unless required by applicable law or agreed to in writing, software
12-
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13-
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14-
* * License for the specific language governing permissions and limitations
15-
* * under the License.
16-
* *
17-
* * SPDX-License-Identifier: Apache-2.0
18-
* *****************************************************************************
19-
*/
20-
21-
package org.eclipse.deeplearning4j.frameworkimport.frameworkimport.onnx.loader
22-
23-
import onnx.Onnx
24-
import org.junit.jupiter.api.Assertions.assertEquals
25-
import org.junit.jupiter.api.Tag
26-
import org.junit.jupiter.api.Test
27-
import org.nd4j.common.tests.tags.TagNames
28-
import org.nd4j.samediff.frameworkimport.onnx.definitions.registry
29-
import org.nd4j.samediff.frameworkimport.onnx.process.OnnxMappingProcessLoader
30-
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
31-
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
32-
@Tag(TagNames.ONNX)
33-
class TestOnnxProcessLoader {
34-
35-
36-
37-
@Test
38-
fun testLoader() {
39-
val onnxOpMappingRegistry = OpMappingRegistry<Onnx.GraphProto, Onnx.NodeProto,
40-
Onnx.NodeProto, Onnx.TensorProto,
41-
Onnx.TensorProto.DataType, Onnx.AttributeProto, Onnx.AttributeProto>(
42-
"onnx", OpDescriptorLoaderHolder.nd4jOpDescriptor)
43-
44-
val loader = OnnxMappingProcessLoader(onnxOpMappingRegistry)
45-
println(loader)
46-
registry().inputFrameworkOpNames().forEach { name ->
47-
if(registry().hasMappingOpProcess(name)) {
48-
val process = registry().lookupOpMappingProcess(name)
49-
val serialized = process.serialize()
50-
val created = loader.createProcess(serialized)
51-
assertEquals(
52-
process,
53-
created,
54-
"Op name $name failed with process tensor rules ${process.tensorMappingRules()} and created tensor rules ${created.tensorMappingRules()} with attributes ${process.attributeMappingRules()} and created attribute rules ${created.attributeMappingRules()}",
55-
56-
)
57-
}
58-
59-
}
60-
}
61-
62-
@Test
63-
fun saveTest() {
64-
registry().saveProcessesAndRuleSet()
65-
val loader = OnnxMappingProcessLoader(registry())
66-
registry().loadFromFile("onnx-processes.pbtxt",loader)
67-
68-
}
69-
}
Original file line numberDiff line numberDiff line change
@@ -1,63 +0,0 @@
1-
/*
2-
* ******************************************************************************
3-
* *
4-
* *
5-
* * This program and the accompanying materials are made available under the
6-
* * terms of the Apache License, Version 2.0 which is available at
7-
* * https://www.apache.org/licenses/LICENSE-2.0.
8-
* *
9-
* * See the NOTICE file distributed with this work for additional
10-
* * information regarding copyright ownership.
11-
* * Unless required by applicable law or agreed to in writing, software
12-
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13-
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14-
* * License for the specific language governing permissions and limitations
15-
* * under the License.
16-
* *
17-
* * SPDX-License-Identifier: Apache-2.0
18-
* *****************************************************************************
19-
*/
20-
21-
package org.eclipse.deeplearning4j.frameworkimport.frameworkimport.tensorflow.loader
22-
23-
import org.junit.jupiter.api.Assertions.assertEquals
24-
import org.junit.jupiter.api.Tag
25-
import org.junit.jupiter.api.Test
26-
import org.nd4j.common.tests.tags.TagNames
27-
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder
28-
import org.nd4j.samediff.frameworkimport.registry.OpMappingRegistry
29-
import org.nd4j.samediff.frameworkimport.tensorflow.definitions.registry
30-
import org.nd4j.samediff.frameworkimport.tensorflow.process.TensorflowMappingProcessLoader
31-
import org.tensorflow.framework.*
32-
33-
@Tag(TagNames.TENSORFLOW)
34-
class TestTensorflowProcessLoader {
35-
36-
@Test
37-
fun testLoader() {
38-
val tensorflowOpMappingRegistry = OpMappingRegistry<GraphDef, NodeDef, OpDef, TensorProto, DataType, OpDef.AttrDef, AttrValue>(
39-
"tensorflow", OpDescriptorLoaderHolder.nd4jOpDescriptor)
40-
41-
val loader = TensorflowMappingProcessLoader(tensorflowOpMappingRegistry)
42-
println(loader)
43-
registry().inputFrameworkOpNames().forEach { name ->
44-
if(registry().hasMappingOpProcess(name)) {
45-
val process = registry().lookupOpMappingProcess(name)
46-
val serialized = process.serialize()
47-
val created = loader.createProcess(serialized)
48-
assertEquals(process,created,"Op name $name failed with process tensor rules ${process.tensorMappingRules()} and created tensor rules ${created.tensorMappingRules()} with attributes ${process.attributeMappingRules()} and created attribute rules ${created.attributeMappingRules()}")
49-
}
50-
51-
}
52-
53-
}
54-
55-
56-
57-
58-
@Test
59-
fun saveTest() {
60-
registry().saveProcessesAndRuleSet()
61-
}
62-
63-
}
Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +0,0 @@
1-
package org.nd4j.samediff.frameworkimport.reflect
2-
3-
import org.apache.commons.io.FileUtils
4-
import org.junit.jupiter.api.Assertions.assertEquals
5-
import org.junit.jupiter.api.Assertions.assertNotNull
6-
import org.junit.jupiter.api.Disabled
7-
import org.junit.jupiter.api.Test
8-
import java.io.File
9-
import java.nio.charset.Charset
10-
import kotlin.test.assertTrue
11-
12-
13-
class ClassGraphHolderTest {
14-
15-
@Test
16-
@Disabled("Takes too long to run.")
17-
fun testClassGraphHolder() {
18-
val jsonFile = File("scanned-classes.json")
19-
ClassGraphHolder.saveScannedClasses(jsonFile)
20-
val original = ClassGraphHolder.scannedClasses
21-
val loadedJson = FileUtils.readFileToString(jsonFile, Charset.defaultCharset())
22-
assertTrue(loadedJson.length > 1,"Json was not written and is empty")
23-
val loaded = ClassGraphHolder.loadFromJson(loadedJson)
24-
assertEquals(original.toJSON(),loaded.toJSON())
25-
assertNotNull(ClassGraphHolder.scannedClasses)
26-
}
27-
28-
}

0 commit comments

Comments
 (0)