Commit 1fc0aec
Fix reference count when an array is used in JIT operations.
Previously, when an af::array backed by a buffer was used in a JIT operation, a buffer node was created and its internal shared_ptr was stored in the Array for future use, to be returned whenever getNode was called. This increased the reference count of the internal buffer, and because the Array kept holding that shared_ptr, the count never decreased. This commit changes the behavior by creating a new buffer node for each call to getNode. The new hash function is used to establish the equality of buffer nodes when the JIT code is generated. This avoids holding the call_once flag in the buffer object and simplifies the management of the buffer node objects. Additionally, when a JIT node goes out of scope, the reference count now decrements as expected.
1 parent 92d0a67 commit 1fc0aec
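
To make the reference-counting effect concrete, here is a minimal, hypothetical sketch (not ArrayFire code; the ArrayCachedNode/ArrayFreshNode/getNode names are stand-ins) of the before/after behavior: caching the buffer node inside the array keeps the data's use_count permanently elevated, while creating a fresh node per getNode call lets the count drop once the JIT tree goes out of scope.

    #include <cassert>
    #include <memory>

    // Hypothetical stand-in for a JIT buffer node that references the array's data.
    struct BufferNode {
        std::shared_ptr<float> data;  // holds an extra reference to the buffer
    };

    struct ArrayCachedNode {  // old behavior: node cached inside the array
        std::shared_ptr<float> data = std::make_shared<float>(0.0f);
        std::shared_ptr<BufferNode> cached;
        std::shared_ptr<BufferNode> getNode() {
            if (!cached) { cached = std::make_shared<BufferNode>(BufferNode{data}); }
            return cached;  // the array keeps the node (and its buffer reference) alive
        }
    };

    struct ArrayFreshNode {  // new behavior: a fresh node for each call
        std::shared_ptr<float> data = std::make_shared<float>(0.0f);
        std::shared_ptr<BufferNode> getNode() {
            return std::make_shared<BufferNode>(BufferNode{data});
        }
    };

    int main() {
        ArrayCachedNode a;
        { auto n = a.getNode(); }         // JIT tree goes out of scope...
        assert(a.data.use_count() == 2);  // ...but the cached node still pins the buffer

        ArrayFreshNode b;
        { auto n = b.getNode(); }         // node destroyed with the JIT tree
        assert(b.data.use_count() == 1);  // reference count drops back as expected
    }

At JIT code-generation time the commit relies on the node hash to recognize that freshly created buffer nodes referring to the same buffer are equal, so duplicate nodes do not produce duplicate kernel parameters.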

File tree

19 files changed: +264 −212 lines changed


src/backend/common/jit/BufferNodeBase.hpp

Lines changed: 6 additions & 11 deletions
@@ -24,25 +24,20 @@ class BufferNodeBase : public common::Node {
     DataType m_data;
     ParamType m_param;
     unsigned m_bytes;
-    std::once_flag m_set_data_flag;
     bool m_linear_buffer;

    public:
-    BufferNodeBase(af::dtype type) : Node(type, 0, {}) {
-        // This class is not movable because of std::once_flag
-    }
+    BufferNodeBase(af::dtype type)
+        : Node(type, 0, {}), m_bytes(0), m_linear_buffer(true) {}

     bool isBuffer() const final { return true; }

     void setData(ParamType param, DataType data, const unsigned bytes,
                  bool is_linear) {
-        std::call_once(m_set_data_flag,
-                       [this, param, data, bytes, is_linear]() {
-                           m_param         = param;
-                           m_data          = data;
-                           m_bytes         = bytes;
-                           m_linear_buffer = is_linear;
-                       });
+        m_param         = param;
+        m_data          = data;
+        m_bytes         = bytes;
+        m_linear_buffer = is_linear;
     }

     bool isLinear(dim_t dims[4]) const final {

src/backend/common/jit/Node.hpp

Lines changed: 11 additions & 10 deletions
@@ -15,13 +15,14 @@
 #include <types.hpp>
 #include <af/defines.h>

+#include <algorithm>
 #include <array>
 #include <cstring>
 #include <functional>
 #include <memory>
+#include <sstream>
 #include <string>
 #include <unordered_map>
-#include <utility>
 #include <vector>

 enum class kJITHeuristics {
@@ -107,15 +108,6 @@ class Node {
     template<typename T>
     friend class NodeIterator;

-    void swap(Node &other) noexcept {
-        using std::swap;
-        for (int i = 0; i < kMaxChildren; i++) {
-            swap(m_children[i], other.m_children[i]);
-        }
-        swap(m_type, other.m_type);
-        swap(m_height, other.m_height);
-    }
-
    public:
     Node() = default;
     Node(const af::dtype type, const int height,
@@ -125,6 +117,15 @@ class Node {
                       "Node is not move assignable");
     }

+    void swap(Node &other) noexcept {
+        using std::swap;
+        for (int i = 0; i < kMaxChildren; i++) {
+            std::swap(m_children[i], other.m_children[i]);
+        }
+        std::swap(m_type, other.m_type);
+        std::swap(m_height, other.m_height);
+    }
+
     /// Default move constructor operator
     Node(Node &&node) noexcept = default;


src/backend/cpu/Array.cpp

Lines changed: 27 additions & 46 deletions
@@ -43,17 +43,19 @@ using common::Node_map_t;
 using common::Node_ptr;
 using common::NodeIterator;
 using cpu::jit::BufferNode;
+
 using std::adjacent_find;
 using std::copy;
 using std::is_standard_layout;
+using std::make_shared;
 using std::move;
 using std::vector;

 namespace cpu {

 template<typename T>
-Node_ptr bufferNodePtr() {
-    return Node_ptr(reinterpret_cast<Node *>(new BufferNode<T>()));
+shared_ptr<BufferNode<T>> bufferNodePtr() {
+    return std::make_shared<BufferNode<T>>();
 }

 template<typename T>
@@ -62,8 +64,7 @@ Array<T>::Array(dim4 dims)
           static_cast<af_dtype>(dtype_traits<T>::af_type))
     , data(memAlloc<T>(dims.elements()).release(), memFree<T>)
     , data_dims(dims)
-    , node(bufferNodePtr<T>())
-    , ready(true)
+    , node()
     , owner(true) {}

 template<typename T>
@@ -75,8 +76,7 @@ Array<T>::Array(const dim4 &dims, T *const in_data, bool is_device,
                     : memAlloc<T>(dims.elements()).release(),
           memFree<T>)
     , data_dims(dims)
-    , node(bufferNodePtr<T>())
-    , ready(true)
+    , node()
     , owner(true) {
     static_assert(is_standard_layout<Array<T>>::value,
                   "Array<T> must be a standard layout type");
@@ -101,7 +101,6 @@ Array<T>::Array(const af::dim4 &dims, Node_ptr n)
     , data()
     , data_dims(dims)
     , node(move(n))
-    , ready(false)
     , owner(true) {}

 template<typename T>
@@ -111,8 +110,7 @@ Array<T>::Array(const Array<T> &parent, const dim4 &dims, const dim_t &offset_,
           static_cast<af_dtype>(dtype_traits<T>::af_type))
     , data(parent.getData())
     , data_dims(parent.getDataDims())
-    , node(bufferNodePtr<T>())
-    , ready(true)
+    , node()
     , owner(false) {}

 template<typename T>
@@ -123,8 +121,7 @@ Array<T>::Array(const dim4 &dims, const dim4 &strides, dim_t offset_,
     , data(is_device ? in_data : memAlloc<T>(info.total()).release(),
           memFree<T>)
     , data_dims(dims)
-    , node(bufferNodePtr<T>())
-    , ready(true)
+    , node()
     , owner(true) {
     if (!is_device) {
         // Ensure the memory being written to isnt used anywhere else.
@@ -135,40 +132,27 @@ Array<T>::Array(const dim4 &dims, const dim4 &strides, dim_t offset_,

 template<typename T>
 void Array<T>::eval() {
-    if (isReady()) { return; }
-    if (getQueue().is_worker()) {
-        AF_ERROR("Array not evaluated", AF_ERR_INTERNAL);
-    }
-
-    this->setId(getActiveDeviceId());
-
-    data = shared_ptr<T>(memAlloc<T>(elements()).release(), memFree<T>);
-
-    getQueue().enqueue(kernel::evalArray<T>, *this, this->node);
-    // Reset shared_ptr
-    this->node = bufferNodePtr<T>();
-    ready = true;
+    evalMultiple<T>({this});
 }

 template<typename T>
 void Array<T>::eval() const {
-    if (isReady()) { return; }
     const_cast<Array<T> *>(this)->eval();
 }

 template<typename T>
 T *Array<T>::device() {
-    getQueue().sync();
     if (!isOwner() || getOffset() || data.use_count() > 1) {
         *this = copyArray<T>(*this);
     }
+    getQueue().sync();
     return this->get();
 }

 template<typename T>
 void evalMultiple(vector<Array<T> *> array_ptrs) {
     vector<Array<T> *> outputs;
-    vector<Node_ptr> nodes;
+    vector<common::Node_ptr> nodes;
     vector<Param<T>> params;
     if (getQueue().is_worker()) {
         AF_ERROR("Array not evaluated", AF_ERR_INTERNAL);
@@ -187,41 +171,39 @@ void evalMultiple(vector<Array<T> *> array_ptrs) {
     }

     for (Array<T> *array : array_ptrs) {
-        if (array->ready) { continue; }
+        if (array->isReady()) { continue; }

         array->setId(getActiveDeviceId());
         array->data =
             shared_ptr<T>(memAlloc<T>(array->elements()).release(), memFree<T>);

         outputs.push_back(array);
-        params.push_back(*array);
+        params.emplace_back(array->getData().get(), array->dims(),
+                            array->strides());
         nodes.push_back(array->node);
     }

-    if (!outputs.empty()) {
-        getQueue().enqueue(kernel::evalMultiple<T>, params, nodes);
-        for (Array<T> *array : outputs) {
-            array->ready = true;
-            array->node = bufferNodePtr<T>();
-        }
-    }
+    if (params.empty()) return;
+
+    getQueue().enqueue(cpu::kernel::evalMultiple<T>, params, nodes);
+
+    for (Array<T> *array : outputs) { array->node.reset(); }
 }

 template<typename T>
 Node_ptr Array<T>::getNode() {
-    if (node->isBuffer()) {
-        auto *bufNode  = reinterpret_cast<BufferNode<T> *>(node.get());
-        unsigned bytes = this->getDataDims().elements() * sizeof(T);
-        bufNode->setData(data, bytes, getOffset(), dims().get(),
-                         strides().get(), isLinear());
-    }
-    return node;
+    if (node) { return node; }
+
+    std::shared_ptr<BufferNode<T>> out = bufferNodePtr<T>();
+    unsigned bytes = this->getDataDims().elements() * sizeof(T);
+    out->setData(data, bytes, getOffset(), dims().get(), strides().get(),
+                 isLinear());
+    return out;
 }

 template<typename T>
 Node_ptr Array<T>::getNode() const {
-    if (node->isBuffer()) { return const_cast<Array<T> *>(this)->getNode(); }
-    return node;
+    return const_cast<Array<T> *>(this)->getNode();
 }

 template<typename T>
@@ -337,7 +319,6 @@ template<typename T>
 void Array<T>::setDataDims(const dim4 &new_dims) {
     modDims(new_dims);
     data_dims = new_dims;
-    if (node->isBuffer()) { node = bufferNodePtr<T>(); }
 }

 #define INSTANTIATE(T) \

src/backend/cpu/Array.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828
#include <vector>
2929

3030
namespace cpu {
31+
32+
namespace jit {
33+
template<typename T>
34+
class BufferNode;
35+
}
36+
3137
namespace kernel {
3238
template<typename T>
3339
void evalArray(Param<T> in, common::Node_ptr node);
@@ -115,15 +121,23 @@ template<typename T>
115121
class Array {
116122
ArrayInfo info; // Must be the first element of Array<T>
117123

118-
// data if parent. empty if child
124+
/// Pointer to the data
119125
std::shared_ptr<T> data;
126+
127+
/// The shape of the underlying parent data.
120128
af::dim4 data_dims;
129+
130+
/// Null if this a buffer node. Otherwise this points to a JIT node
121131
common::Node_ptr node;
122132

123-
bool ready;
133+
/// If true, the Array object is the parent. If false the data object points
134+
/// to another array's data
124135
bool owner;
125136

137+
/// Default constructor
126138
Array() = default;
139+
140+
/// Creates an uninitialized array of a specific shape
127141
Array(dim4 dims);
128142

129143
explicit Array(const af::dim4 &dims, T *const in_data, bool is_device,
@@ -149,7 +163,6 @@ class Array {
149163
swap(data, other.data);
150164
swap(data_dims, other.data_dims);
151165
swap(node, other.node);
152-
swap(ready, other.ready);
153166
swap(owner, other.owner);
154167
}
155168

@@ -198,7 +211,7 @@ class Array {
198211

199212
~Array() = default;
200213

201-
bool isReady() const { return ready; }
214+
bool isReady() const { return static_cast<bool>(node) == false; }
202215

203216
bool isOwner() const { return owner; }
204217

@@ -236,10 +249,7 @@ class Array {
236249
return data.get() + (withOffset ? getOffset() : 0);
237250
}
238251

239-
int useCount() const {
240-
if (!data.get()) eval();
241-
return static_cast<int>(data.use_count());
242-
}
252+
int useCount() const { return static_cast<int>(data.use_count()); }
243253

244254
operator Param<T>() {
245255
return Param<T>(this->get(), this->dims(), this->strides());

src/backend/cpu/binary.hpp

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
  ********************************************************/
 #pragma once

+#include <jit/Node.hpp>
 #include <math.hpp>
 #include <optypes.hpp>
 #include <types.hpp>

src/backend/cpu/jit/BinaryNode.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ class BinaryNode : public TNode<compute_t<To>> {
4343
m_op.eval(this->m_val, m_lhs->m_val, m_rhs->m_val, lim);
4444
}
4545

46+
void replaceChild(int id, void *ptr) final {
47+
if (id == 0) { m_lhs = static_cast<TNode<compute_t<Ti>> *>(ptr); }
48+
if (id == 1) { m_rhs = static_cast<TNode<compute_t<Ti>> *>(ptr); }
49+
}
50+
4651
void calc(int idx, int lim) final {
4752
UNUSED(idx);
4853
m_op.eval(this->m_val, m_lhs->m_val, m_rhs->m_val, lim);
