diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py index 0b4d30a..f5ac963 100644 --- a/onnx_array_api/plotting/text_plot.py +++ b/onnx_array_api/plotting/text_plot.py @@ -64,7 +64,10 @@ def __init__(self, i, atts): self.nodes_missing_value_tracks_true = None for k, v in atts.items(): if k.startswith("nodes"): - setattr(self, k, v[i]) + if k.endswith("_as_tensor"): + setattr(self, k.replace("_as_tensor", ""), v[i]) + else: + setattr(self, k, v[i]) self.depth = 0 self.true_false = "" self.targets = [] @@ -120,10 +123,7 @@ def process_tree(atts, treeid): ] for k, v in atts.items(): if k.startswith(prefix): - if "classlabels" in k: - short[k] = list(v) - else: - short[k] = [v[i] for i in idx] + short[k] = list(v) if "classlabels" in k else [v[i] for i in idx] nodes = OrderedDict() for i in range(len(short["nodes_treeids"])): @@ -132,9 +132,10 @@ def process_tree(atts, treeid): for i in range(len(short[f"{prefix}_treeids"])): idn = short[f"{prefix}_nodeids"][i] node = nodes[idn] - node.append_target( - tid=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i] - ) + key = f"{prefix}_weights" + if key not in short: + key = f"{prefix}_weights_as_tensor" + node.append_target(tid=short[f"{prefix}_ids"][i], weight=short[key][i]) def iterate(nodes, node, depth=0, true_false=""): node.depth = depth