Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 6e78529

Browse files
[OpenVINO backend] Support numpy.prod (#21567)
* feat: prod numpy for openvino backend * feat: included tests for prod * fix: handled dtype permosion
1 parent 5dbdf60 commit 6e78529

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

‎keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ NumpyDtypeTest::test_median
4848
NumpyDtypeTest::test_minimum_python_types
4949
NumpyDtypeTest::test_multiply
5050
NumpyDtypeTest::test_power
51-
NumpyDtypeTest::test_prod
5251
NumpyDtypeTest::test_quantile
5352
NumpyDtypeTest::test_roll
5453
NumpyDtypeTest::test_round
@@ -107,7 +106,6 @@ NumpyOneInputOpsCorrectnessTest::test_pad_int16_constant_2
107106
NumpyOneInputOpsCorrectnessTest::test_pad_int8_constant_2
108107
NumpyOneInputOpsCorrectnessTest::test_pad_uint8_constant_2
109108
NumpyOneInputOpsCorrectnessTest::test_pad_int32_constant_2
110-
NumpyOneInputOpsCorrectnessTest::test_prod
111109
NumpyOneInputOpsCorrectnessTest::test_real
112110
NumpyOneInputOpsCorrectnessTest::test_reshape
113111
NumpyOneInputOpsCorrectnessTest::test_roll

‎keras/src/backend/openvino/numpy.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,35 @@ def pad(x, pad_width, mode="constant", constant_values=None):
14031403

14041404

14051405
def prod(x, axis=None, keepdims=False, dtype=None):
1406-
raise NotImplementedError("`prod` is not supported with openvino backend")
1406+
x = get_ov_output(x)
1407+
1408+
# If a specific dtype is requested, cast the input to that dtype.
1409+
if dtype is not None:
1410+
ov_dtype = OPENVINO_DTYPES[standardize_dtype(dtype)]
1411+
x = ov_opset.convert(x, ov_dtype).output(0)
1412+
# Otherwise, apply dtype promotion rules before reduction.
1413+
else:
1414+
x_type = x.get_element_type()
1415+
if x_type == Type.boolean:
1416+
x = ov_opset.convert(x, Type.i32).output(0)
1417+
elif x_type in (Type.i8, Type.i16):
1418+
x = ov_opset.convert(x, Type.i32).output(0)
1419+
elif x_type in (Type.u8, Type.u16):
1420+
x = ov_opset.convert(x, Type.u32).output(0)
1421+
1422+
if axis is None:
1423+
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
1424+
x = ov_opset.reshape(x, flatten_shape, False).output(0)
1425+
axis = 0
1426+
1427+
if isinstance(axis, tuple):
1428+
axis = list(axis)
1429+
axis = ov_opset.constant(axis, Type.i32).output(0)
1430+
1431+
# Compute the product
1432+
result = ov_opset.reduce_prod(x, axis, keepdims).output(0)
1433+
1434+
return OpenVINOKerasTensor(result)
14071435

14081436

14091437
def quantile(x, q, axis=None, method="linear", keepdims=False):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /