Commit cbe9e2b

Update
1 parent adc546e commit cbe9e2b

4 files changed: +15, -43 lines changed

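All four files apply the same simplification: an `if` block whose only body was an unconditional `TORCH_CHECK(false, msg)` is collapsed into a single `TORCH_CHECK(cond, msg)` with the condition folded into the check itself. Behaviour is unchanged: the same error is raised whenever the condition does not hold. Below is a minimal standalone sketch of the before/after shape, using a local `MY_CHECK` macro as a stand-in for `TORCH_CHECK` so it compiles outside PyTorch; the function names and message text are illustrative, not taken from this commit.

#include <iostream>
#include <stdexcept>
#include <string>

// Stand-in for TORCH_CHECK: throws when the condition is false.
#define MY_CHECK(cond, msg)            \
  do {                                 \
    if (!(cond)) {                     \
      throw std::runtime_error(msg);   \
    }                                  \
  } while (0)

// Before: an explicit branch guarding an unconditional failure.
void check_has_type_before(const std::string& fmt) {
  auto space = fmt.find(' ');
  if (space == std::string::npos) {
    MY_CHECK(false, "missing type: " + fmt);
  }
}

// After: the condition is expressed directly in the check.
void check_has_type_after(const std::string& fmt) {
  auto space = fmt.find(' ');
  MY_CHECK(space != std::string::npos, "missing type: " + fmt);
}

int main() {
  check_has_type_after("Tensor self");  // passes: the format string has a type
  try {
    check_has_type_after("self");       // fails: no space, so no type
  } catch (const std::runtime_error& e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}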

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 8 additions & 24 deletions
@@ -131,9 +131,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
       size(0),
       default_scalar(0) {
   auto space = fmt.find(' ');
-  if (space == std::string::npos) {
-    TORCH_CHECK(false, "FunctionParameter(): missing type: " + fmt);
-  }
+  TORCH_CHECK(space != std::string::npos, "FunctionParameter(): missing type: " + fmt);
 
   auto type_str = fmt.substr(0, space);
 
@@ -154,9 +152,7 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
 
   auto name_str = fmt.substr(space + 1);
   auto it = type_map.find(type_str);
-  if (it == type_map.end()) {
-    TORCH_CHECK(false, "FunctionParameter(): invalid type string: " + type_str);
-  }
+  TORCH_CHECK(it != type_map.end(), "FunctionParameter(): invalid type string: " + type_str);
   type_ = it->second;
 
   auto eq = name_str.find('=');
@@ -1226,9 +1222,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
   }
   if (type_ == ParameterType::TENSOR ||
       type_ == ParameterType::DISPATCH_KEY_SET) {
-    if (str != "None") {
-      TORCH_CHECK(false, "default value for Tensor must be none, got: " + str);
-    }
+    TORCH_CHECK(str == "None", "default value for Tensor must be none, got: " + str);
   } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
     default_int = atol(str.c_str());
   } else if (type_ == ParameterType::BOOL) {
@@ -1252,9 +1246,7 @@ void FunctionParameter::set_default_str(const std::string& str) {
       default_intlist = parse_intlist_args(str, size);
     }
   } else if (type_ == ParameterType::FLOAT_LIST) {
-    if (str != "None") {
-      TORCH_CHECK(false, "Defaults not supported for float[]");
-    }
+    TORCH_CHECK(str == "None", "Defaults not supported for float[]");
   } else if (type_ == ParameterType::SCALARTYPE) {
     if (str == "None") {
       default_scalartype = at::ScalarType::Undefined;
@@ -1274,13 +1266,9 @@ void FunctionParameter::set_default_str(const std::string& str) {
       TORCH_CHECK(false, "invalid default value for layout: " + str);
     }
   } else if (type_ == ParameterType::DEVICE) {
-    if (str != "None") {
-      TORCH_CHECK(false, "invalid device: " + str);
-    }
+    TORCH_CHECK(str == "None", "invalid device: " + str);
   } else if (type_ == ParameterType::STREAM) {
-    if (str != "None") {
-      TORCH_CHECK(false, "invalid stream: " + str);
-    }
+    TORCH_CHECK(str == "None", "invalid stream: " + str);
   } else if (type_ == ParameterType::STRING) {
     if (str != "None") {
       default_string = parse_string_literal(str);
@@ -1346,12 +1334,8 @@ FunctionSignature::FunctionSignature(const std::string& fmt, int index)
       break;
     }
   }
-  if (offset == std::string::npos) {
-    TORCH_CHECK(false, "missing closing parenthesis: " + fmt);
-  }
-  if (offset == last_offset) {
-    TORCH_CHECK(false, "malformed signature: " + fmt);
-  }
+  TORCH_CHECK(offset != std::string::npos, "missing closing parenthesis: " + fmt);
+  TORCH_CHECK(offset != last_offset, "malformed signature: " + fmt);
 
   auto param_str = fmt.substr(last_offset, offset - last_offset);
   last_offset = next_offset;

torch/csrc/utils/python_numbers.h

Lines changed: 3 additions & 7 deletions
@@ -199,12 +199,8 @@ inline c10::DeviceIndex THPUtils_unpackDeviceIndex(PyObject* obj) {
   if (value == -1 && PyErr_Occurred()) {
     throw python_error();
   }
-  if (overflow != 0) {
-    TORCH_CHECK(false, "Overflow when unpacking DeviceIndex");
-  }
-  if (value > std::numeric_limits<c10::DeviceIndex>::max() ||
-      value < std::numeric_limits<c10::DeviceIndex>::min()) {
-    TORCH_CHECK(false, "Overflow when unpacking DeviceIndex");
-  }
+  TORCH_CHECK(overflow == 0, "Overflow when unpacking DeviceIndex");
+  TORCH_CHECK(value <= std::numeric_limits<c10::DeviceIndex>::max() &&
+      value >= std::numeric_limits<c10::DeviceIndex>::min(), "Overflow when unpacking DeviceIndex");
   return (c10::DeviceIndex)value;
 }
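Note on this hunk: the `overflow` flag comes from the CPython long conversion earlier in the function, and the second check rejects values that do not fit in `c10::DeviceIndex`. A minimal standalone sketch of that narrowing check follows, with `int16_t` standing in for the real `c10::DeviceIndex` type and a plain exception in place of `TORCH_CHECK`; all names here are illustrative.

#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>

// int16_t is only a stand-in for the narrow device-index type.
using DeviceIndexLike = std::int16_t;

DeviceIndexLike narrow_to_device_index(long long value) {
  // Reject anything outside the representable range of the target type.
  if (value > std::numeric_limits<DeviceIndexLike>::max() ||
      value < std::numeric_limits<DeviceIndexLike>::min()) {
    throw std::overflow_error("Overflow when unpacking DeviceIndex");
  }
  return static_cast<DeviceIndexLike>(value);
}

int main() {
  std::cout << narrow_to_device_index(3) << '\n';  // fits: prints 3
  try {
    narrow_to_device_index(1000000);               // too large for int16_t
  } catch (const std::overflow_error& e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}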

torch/csrc/utils/python_strings.h

Lines changed: 2 additions & 6 deletions
@@ -26,9 +26,7 @@ inline std::string THPUtils_unpackString(PyObject* obj) {
   if (PyUnicode_Check(obj)) {
     Py_ssize_t size = 0;
     const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
-    if (!data) {
-      TORCH_CHECK(false, "error unpacking string as utf-8");
-    }
+    TORCH_CHECK(data, "error unpacking string as utf-8");
     return std::string(data, (size_t)size);
   }
   TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
@@ -50,9 +48,7 @@ inline std::string_view THPUtils_unpackStringView(PyObject* obj) {
   if (PyUnicode_Check(obj)) {
     Py_ssize_t size = 0;
     const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
-    if (!data) {
-      TORCH_CHECK(false, "error unpacking string as utf-8");
-    }
+    TORCH_CHECK(data, "error unpacking string as utf-8");
     return std::string_view(data, (size_t)size);
   }
   TORCH_CHECK(false, "unpackString: expected bytes or unicode object");
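Both unpack helpers rely on `PyUnicode_AsUTF8AndSize`, which returns `nullptr` (with a Python error set) when the object cannot be encoded as UTF-8, so `TORCH_CHECK(data, ...)` now covers that failure directly. A minimal standalone sketch of the same pattern against the CPython C API follows, with a plain exception in place of `TORCH_CHECK`; the helper name, error text, and `main` driver are illustrative.

#include <Python.h>  // CPython C API; include before other headers

#include <cstdio>
#include <stdexcept>
#include <string>

std::string unpack_utf8(PyObject* obj) {
  if (PyUnicode_Check(obj)) {
    Py_ssize_t size = 0;
    // Returns a pointer into the object's cached UTF-8 buffer, or nullptr
    // (with a Python error set) if the string cannot be encoded.
    const char* data = PyUnicode_AsUTF8AndSize(obj, &size);
    if (!data) {
      throw std::runtime_error("error unpacking string as utf-8");
    }
    return std::string(data, static_cast<size_t>(size));
  }
  throw std::runtime_error("expected a unicode object");
}

int main() {
  Py_Initialize();
  PyObject* s = PyUnicode_FromString("hello");
  std::printf("%s\n", unpack_utf8(s).c_str());  // prints: hello
  Py_DECREF(s);
  Py_FinalizeEx();
  return 0;
}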

torch/csrc/utils/tensor_numpy.cpp

Lines changed: 2 additions & 6 deletions
@@ -213,9 +213,7 @@ void warn_numpy_not_writeable() {
 at::Tensor tensor_from_numpy(
     PyObject* obj,
     bool warn_if_not_writeable /*=true*/) {
-  if (!is_numpy_available()) {
-    TORCH_CHECK(false, "Numpy is not available");
-  }
+  TORCH_CHECK(is_numpy_available(), "Numpy is not available");
   TORCH_CHECK_TYPE(
       PyArray_Check(obj),
       "expected np.ndarray (got ",
@@ -381,9 +379,7 @@ bool is_numpy_scalar(PyObject* obj) {
 }
 
 at::Tensor tensor_from_cuda_array_interface(PyObject* obj) {
-  if (!is_numpy_available()) {
-    TORCH_CHECK(false, "Numpy is not available");
-  }
+  TORCH_CHECK(is_numpy_available(), "Numpy is not available");
   auto cuda_dict =
       THPObjectPtr(PyObject_GetAttrString(obj, "__cuda_array_interface__"));
   TORCH_INTERNAL_ASSERT(cuda_dict);
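Both entry points in this file guard on `is_numpy_available()` before touching the NumPy C API. A sketch of that guard-on-entry pattern in isolation, with a cached availability probe and a plain exception in place of `TORCH_CHECK`; the names and the always-true probe are illustrative, not PyTorch's implementation.

#include <iostream>
#include <stdexcept>

namespace {

// Illustrative probe; a real check would attempt to initialize the
// NumPy C API once and remember whether that succeeded.
bool detect_numpy() {
  return true;
}

// Computed on first use, then cached for every later call.
bool numpy_available() {
  static const bool available = detect_numpy();
  return available;
}

}  // namespace

void tensor_from_numpy_like(/* PyObject* obj */) {
  if (!numpy_available()) {
    throw std::runtime_error("Numpy is not available");
  }
  // ... conversion would go here ...
}

int main() {
  tensor_from_numpy_like();
  std::cout << "numpy guard passed\n";
  return 0;
}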
