Skip to content

Commit ac9069d

Browse files
committed
debugging test_cr_history_indexing.py file
1 parent 579fe22 commit ac9069d

File tree

2 files changed

+123
-2
lines changed

2 files changed

+123
-2
lines changed

mltrace/db/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,7 +545,7 @@ def get_component_runs_by_index(
545545
"""Gets lineage for the component, or a history of all its runs."""
546546

547547
if first_idx < 0 or last_idx < 0:
548-
total_length = self.session.query(ComponentRun).count()
548+
total_length = self.get_component_runs_count(component_name)
549549
if first_idx < 0:
550550
first_idx = first_idx + total_length
551551
if last_idx < 0:
@@ -562,7 +562,7 @@ def get_component_runs_by_index(
562562
return history
563563

564564
def get_component_runs_count(self, component_name: str):
565-
return self.session.query(ComponentRun).count()
565+
return self.session.query(ComponentRun).filter(ComponentRun.component_name == component_name).count()
566566

567567
def get_components(self, tag: str = "", owner: str = ""):
568568
"""Returns a list of all the components associated with the specified

tests/test_cr_history_indexing.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from mltrace import Component
2+
import unittest
3+
from mltrace import (
4+
set_db_uri,
5+
)
6+
from mltrace.db import Store
7+
8+
9+
class TestHistoryComponent(Component):
10+
def __init__(
11+
self,
12+
name="",
13+
owner="",
14+
description="",
15+
beforeTests=[],
16+
afterTests=[],
17+
tags=[],
18+
):
19+
super().__init__(
20+
name, owner, description, beforeTests, afterTests, tags)
21+
22+
23+
def isEqualComponentRun(crOne, crTwo):
24+
if crOne.start_timestamp == crTwo.start_timestamp and crOne.end_timestamp == crTwo.end_timestamp:
25+
return True
26+
return False
27+
28+
29+
class TestComponentRunHistory(unittest.TestCase):
30+
def setUp(self):
31+
set_db_uri("test")
32+
self.new_component = TestHistoryComponent("history_test_six", "boyuan")
33+
34+
@self.new_component.run(auto_log=True)
35+
def function(num):
36+
num += 2
37+
return num ** 2
38+
39+
# initialize four component run for TestHistoryComponent
40+
for i in range (0, 4):
41+
function(i)
42+
43+
self.historyLength = len(self.new_component.history)
44+
self.firstComponentRun = self.new_component.history.get_runs_by_index(0, 1)
45+
self.lastComponentRun = self.new_component.history.get_runs_by_index(self.historyLength - 1, self.historyLength)[0]
46+
self.secondAndThirdComponentRun = self.new_component.history.get_runs_by_index(1, 3)
47+
48+
def testPstPstIndex(self):
49+
50+
# case 1: (0, 0) return zero componentRun
51+
resCrList = self.new_component.history.get_runs_by_index(0, 0)
52+
self.assertEqual(len(resCrList), 0)
53+
54+
# case 2: (0, 1) return the first componentRun
55+
resCrList = self.new_component.history.get_runs_by_index(0, 1)
56+
self.assertEqual(len(resCrList), 1)
57+
self.assertTrue(isEqualComponentRun(resCrList[0], self.firstComponentRun))
58+
59+
# case 3: (len(componentRun) - 1, len(componentRun)) return the last componentRun
60+
resCrList = self.new_component.history.get_runs_by_index(self.historyLength - 1, self.historyLength)
61+
self.assertEqual(len(resCrList), 1)
62+
self.assertEqual(isEqualComponentRun(resCrList[0], self.lastComponentRun))
63+
64+
# case 4: (1, 3) return the second and third componentRun given len(componentRun) >=3)
65+
resCrList = self.new_component.history.get_runs_by_index(1, 3)
66+
self.assertEqual(len(resCrList), 2)
67+
for idx, cr in enumerate(resCrList):
68+
if idx >= len(self.secondAndThirdComponentRun):
69+
break
70+
else:
71+
self.assertEqual(isEqualComponentRun(cr, self.secondAndThirdComponentRun[idx]))
72+
73+
74+
def testPstNgtIndex(self):
75+
76+
# case 1: (0, -len(componentRun) + 1) return first componentRun)
77+
resCrList = self.new_component.history.get_runs_by_index(0, -self.historyLength + 1)
78+
self.assertEqual(len(resCrList), 1)
79+
self.assertEqual(isEqualComponentRun(resCrList[0], self.firstComponentRun))
80+
81+
# case 2: cannot retrieve last componentRun this way
82+
83+
# case 3: (1, -len(componentRun) + 3) return second and third componentRun given len(componentRun) >=4)
84+
resCrList = self.new_component.history.get_runs_by_index(1, -self.historyLength + 3)
85+
self.assertEqual(len(resCrList), 2)
86+
for idx, cr in enumerate(resCrList):
87+
self.assertEqual(isEqualComponentRun(cr, self.secondAndThirdComponentRun[idx]))
88+
89+
def testNgtPstIndex(self):
90+
91+
# case 1: (-len(componentRun), 1) return the first componentRun)
92+
resCrList = self.new_component.history.get_runs_by_index(-self.historyLength, 1)
93+
self.assertEqual(len(resCrList), 1)
94+
self.assertEqual(isEqualComponentRun(resCrList[0], self.firstComponentRun))
95+
96+
# case 2: (-1, len(componentRun)) return the last componentRun)
97+
resCrList = self.new_component.history.get_runs_by_index(-1, self.historyLength)
98+
self.assertEqual(len(resCrList),)
99+
self.assertEqual(isEqualComponentRun(resCrList[0], self.lastComponentRun))
100+
101+
# case 3: (-len(componentRun) + 1, 3) return second and third componentRun given len(componentRun) >=3)
102+
resCrList = self.new_component.history.get_runs_by_index(-self.historyLength + 1, 3)
103+
self.assertEqual(len(resCrList), 2)
104+
for idx, cr in enumerate(resCrList):
105+
self.assertEqual(isEqualComponentRun(cr, self.secondAndThirdComponentRun[idx]))
106+
107+
108+
def testNgtNgtIndex(self):
109+
110+
# case 1: (-len(componentRun), -len(componentRun) + 1) return first componentRun)
111+
resCrList = self.new_component.history.get_runs_by_index(-self.historyLength, -self.historyLength + 1)
112+
self.assertEqual(len(resCrList), 1)
113+
self.assertEqual(isEqualComponentRun(resCrList[0], self.firstComponentRun))
114+
115+
# case 2: cannot retrieve last componentRun this way
116+
117+
# case 3: (-len(componentRun) + 1, -len(componentRun) + 3) return second and third componentRun given len(componentRun) >=4)
118+
resCrList = self.new_component.history.get_runs_by_index(-self.historyLength + 1, -self.historyLength + 3)
119+
self.assertEqual(len(resCrList), 2)
120+
for idx, cr in enumerate(resCrList):
121+
self.assertEqual(isEqualComponentRun(cr, self.secondAndThirdComponentRun[idx]))

0 commit comments

Comments
 (0)