@@ -91,15 +91,32 @@ class DecodeBmpOp : public OpKernel {
91
91
errors::InvalidArgument (
92
92
" Number of channels must be 1, 3 or 4, was " , channels_));
93
93
94
+ OP_REQUIRES (context, width > 0 && header_size >= 0 ,
95
+ errors::InvalidArgument (" Width must be positive" ));
96
+ OP_REQUIRES (context, header_size >= 0 ,
97
+ errors::InvalidArgument (" header size must be nonnegative" ));
98
+
99
+ // The real requirement is < 2^31 minus some headers and channel data,
100
+ // so rounding down to something that's still ridiculously big.
101
+ OP_REQUIRES (
102
+ context,
103
+ (static_cast <int64>(width) * std::abs (static_cast <int64>(height))) <
104
+ static_cast <int64>(std::numeric_limits<int32_t >::max () / 8 ),
105
+ errors::InvalidArgument (
106
+ " Total possible pixel bytes must be less than 2^30" ));
107
+
108
+ const int32 abs_height = abs (height);
109
+
94
110
// there may be padding bytes when the width is not a multiple of 4 bytes
95
111
// 8 * channels == bits per pixel
96
112
const int row_size = (8 * channels_ * width + 31 ) / 32 * 4 ;
97
113
98
- const int last_pixel_offset =
99
- header_size + (abs (height) - 1 ) * row_size + (width - 1 ) * channels_;
114
+ const int64 last_pixel_offset = static_cast <int64>(header_size) +
115
+ (abs_height - 1 ) * row_size +
116
+ (width - 1 ) * channels_;
100
117
101
118
// [expected file size] = [last pixel offset] + [last pixel size=channels]
102
- const int expected_file_size = last_pixel_offset + channels_;
119
+ const int64 expected_file_size = last_pixel_offset + channels_;
103
120
104
121
OP_REQUIRES (
105
122
context, (expected_file_size <= input.size ()),
@@ -115,12 +132,12 @@ class DecodeBmpOp : public OpKernel {
115
132
Tensor* output = nullptr ;
116
133
OP_REQUIRES_OK (
117
134
context, context->allocate_output (
118
- 0 , TensorShape ({abs (height) , width, channels_}), &output));
135
+ 0 , TensorShape ({abs_height , width, channels_}), &output));
119
136
120
137
const uint8* bmp_pixels = &img_bytes[header_size];
121
138
122
139
Decode (bmp_pixels, row_size, output->flat <uint8>().data (), width,
123
- abs (height) , channels_, top_down);
140
+ abs_height , channels_, top_down);
124
141
}
125
142
126
143
uint8* Decode (const uint8* input, const int row_size, uint8* const output,
0 commit comments