Portfolio class tracking positions, cash, equity, and total value.

Source code in Backtesting/portfolio.py
class portfolio:
    """Portfolio class tracking positions, cash, equity, and total value.

    Attributes:
        positions: Stock symbol -> number of shares held. Symbols whose share
            count drops to zero are removed.
        cash: Uninvested cash.
        equity: Value of ``positions`` at the prices used by the last
            ``set_portfolio`` call (0 until the first call).
        total: ``cash + equity``.
    """

    positions: dict[str, int]  # key is stock, value is num_shares
    cash: float
    equity: float
    total: float

    def __init__(self, positions: dict, cash: float, time_step: int):
        """Initialize portfolio with starting positions and cash.

        Args:
            positions: Dictionary mapping stock symbols to share counts. It is
                copied, so later trades never mutate the caller's dict.
            cash: Starting cash amount.
            time_step: Initial time step (unused, kept for consistency).

        Note:
            ``equity`` starts at 0 even when ``positions`` is non-empty,
            because no price source is available here; it is first computed
            by ``set_portfolio``.
        """
        # Copy so the portfolio owns its state: storing the argument directly
        # would make every trade silently mutate the caller's dictionary.
        self.positions = dict(positions)
        self.cash = cash
        self.equity = 0
        self.total = cash + self.equity

    def set_portfolio(self, time_step: int, output: list, data_obj: "DataAPI"):
        """Apply a batch of trades, then revalue holdings at current prices.

        The batch is applied atomically. Every trade is validated against a
        staged copy of positions and cash, and the portfolio is updated only
        once the whole batch and the revaluation succeed. If anything is
        rejected, the portfolio is left exactly as it was.

        Args:
            time_step: Current time step for price lookup.
            output: List of trade tuples (stock, shares, flag, price), where
                flag is 1 for a buy, -1 for a sell and 0 for no change, shares
                is a non-negative share count and price is the execution price.
            data_obj: DataAPI instance for price retrieval; only its
                ``get_price(stock, time_step)`` method is used.

        Raises:
            InvalidPortfolioError: If a trade has an unknown flag or a negative
                share count, a buy costs more than the available cash, a sell
                exceeds the shares held, or the resulting holdings, cash or
                equity would be negative.
        """
        # Stage all changes on local copies; commit only after full validation.
        positions = dict(self.positions)
        cash = self.cash
        for trade in output:
            stock = trade[0]
            num_shares = trade[1]
            flag = trade[2]
            stock_price = trade[3]
            # Any other flag would scale the trade and skip the cash/share checks.
            if flag not in (-1, 0, 1):
                raise InvalidPortfolioError("flag must be -1, 0, or 1")
            # A negative count would turn a sell into an unchecked buy (and vice versa).
            if num_shares < 0:
                raise InvalidPortfolioError("trade num_shares cannot be < 0")
            transaction_cost = num_shares * stock_price
            current_shares = positions.get(stock, 0)
            if flag == 1 and transaction_cost > cash:
                raise InvalidPortfolioError("buy > cash")
            if flag == -1 and num_shares > current_shares:
                raise InvalidPortfolioError("sell > shares")
            new_shares = current_shares + flag * num_shares
            # Check before writing so a rejected trade never leaves a bad holding.
            if new_shares < 0:
                raise InvalidPortfolioError("Num_shares cannot be < 0")
            if new_shares == 0:  # delete stocks from portfolio if no shares
                positions.pop(stock, None)
            else:
                positions[stock] = new_shares
            cash -= flag * transaction_cost

        # Equity must be recomputed at every time step from current prices.
        equity = 0
        for stock, shares in positions.items():
            if shares < 0:  # can only come from negative starting positions
                raise InvalidPortfolioError("Num_shares cannot be < 0")
            equity += shares * data_obj.get_price(stock, time_step)
        if equity < 0 or cash < 0:
            raise InvalidPortfolioError("cash or equity < 0")

        # Commit in place so existing references to self.positions stay valid.
        self.positions.clear()
        self.positions.update(positions)
        self.cash = cash
        self.equity = equity
        self.total = self.cash + self.equity
__init__(positions, cash, time_step)

Initialize portfolio with starting positions and cash.

Parameters:
- positions (dict) – Dictionary mapping stock symbols to share counts.
- cash (float) – Starting cash amount.
- time_step (int) – Initial time step (unused, kept for consistency).

Source code in Backtesting/portfolio.py
def __init__(self, positions: dict, cash: float, time_step: int):
    """Initialize portfolio with starting positions and cash.

    Args:
        positions: Dictionary mapping stock symbols to share counts. It is
            copied, so later trades never mutate the caller's dict.
        cash: Starting cash amount.
        time_step: Initial time step (unused, kept for consistency).

    Note:
        ``equity`` starts at 0 even when ``positions`` is non-empty, because
        no price source is available here; it is first computed by
        ``set_portfolio``.
    """
    # Copy so the portfolio owns its state: storing the argument directly
    # would make every trade silently mutate the caller's dictionary.
    self.positions = dict(positions)
    self.cash = cash
    self.equity = 0
    self.total = cash + self.equity
set_portfolio(time_step, output, data_obj)

Update portfolio with trades and recalculate equity at current prices.

Parameters:
- time_step (int) – Current time step for price lookup.
- output (list) – List of trade tuples (stock, shares, flag, price).
- data_obj (DataAPI) – DataAPI instance for price retrieval.

Raises:
- InvalidPortfolioError – If a trade would result in an invalid portfolio state.

Source code in Backtesting/portfolio.py
def set_portfolio(self, time_step: int, output: list, data_obj: "DataAPI"):
    """Apply a batch of trades, then revalue holdings at current prices.

    The batch is applied atomically. Every trade is validated against a
    staged copy of positions and cash, and the portfolio is updated only once
    the whole batch and the revaluation succeed. If anything is rejected, the
    portfolio is left exactly as it was.

    Args:
        time_step: Current time step for price lookup.
        output: List of trade tuples (stock, shares, flag, price), where flag
            is 1 for a buy, -1 for a sell and 0 for no change, shares is a
            non-negative share count and price is the execution price.
        data_obj: DataAPI instance for price retrieval; only its
            ``get_price(stock, time_step)`` method is used.

    Raises:
        InvalidPortfolioError: If a trade has an unknown flag or a negative
            share count, a buy costs more than the available cash, a sell
            exceeds the shares held, or the resulting holdings, cash or equity
            would be negative.
    """
    # Stage all changes on local copies; commit only after full validation.
    positions = dict(self.positions)
    cash = self.cash
    for trade in output:
        stock = trade[0]
        num_shares = trade[1]
        flag = trade[2]
        stock_price = trade[3]
        # Any other flag would scale the trade and skip the cash/share checks.
        if flag not in (-1, 0, 1):
            raise InvalidPortfolioError("flag must be -1, 0, or 1")
        # A negative count would turn a sell into an unchecked buy (and vice versa).
        if num_shares < 0:
            raise InvalidPortfolioError("trade num_shares cannot be < 0")
        transaction_cost = num_shares * stock_price
        current_shares = positions.get(stock, 0)
        if flag == 1 and transaction_cost > cash:
            raise InvalidPortfolioError("buy > cash")
        if flag == -1 and num_shares > current_shares:
            raise InvalidPortfolioError("sell > shares")
        new_shares = current_shares + flag * num_shares
        # Check before writing so a rejected trade never leaves a bad holding.
        if new_shares < 0:
            raise InvalidPortfolioError("Num_shares cannot be < 0")
        if new_shares == 0:  # delete stocks from portfolio if no shares
            positions.pop(stock, None)
        else:
            positions[stock] = new_shares
        cash -= flag * transaction_cost

    # Equity must be recomputed at every time step from current prices.
    equity = 0
    for stock, shares in positions.items():
        if shares < 0:  # can only come from negative starting positions
            raise InvalidPortfolioError("Num_shares cannot be < 0")
        equity += shares * data_obj.get_price(stock, time_step)
    if equity < 0 or cash < 0:
        raise InvalidPortfolioError("cash or equity < 0")

    # Commit in place so existing references to self.positions stay valid.
    self.positions.clear()
    self.positions.update(positions)
    self.cash = cash
    self.equity = equity
    self.total = self.cash + self.equity
|