
class Portfolio:
    """Trading portfolio management class.

    Tracks cash, open positions, closed trades and a per-date history of
    portfolio snapshots. Dates passed to buy/sell/update must never go
    backwards in time.

    Attributes:
        positions: ticker -> ``[buy_date, quantity, buy_price, last_date,
            last_price]`` for every open position.
        trades: ``TradeData`` records of closed positions, in the order the
            sells happened.
        pos_history: record date -> ``(cash, [(ticker, qty, last_price), ...])``
            snapshot written after every buy, sell and update.
        events: date -> list of human-readable event/warning messages.
        total_fees_paid: Cumulative fees charged on buys and sells.
        max_date: Latest date seen so far (``None`` before the first call).
    """

    # Fraction of available cash invested by buy() when no fixed_val is given.
    DEFAULT_PERCENT = 1.0
    # Smallest fixed_val accepted by buy().
    MIN_AMOUNT = 500

    def __init__(
        self, cash, single=True, warn=False, allow_fractional=False, fees=0.0015
    ):
        """Create a portfolio.

        Args:
            cash: Starting cash; must be at least 10,000.
            single: Single-ticker mode. At most one ticker may be held at a
                time, and buy() may be called without ``fixed_val``.
            warn: Print diagnostic messages while trading.
            allow_fractional: Keep fractional share quantities instead of
                truncating them to whole shares.
            fees: Fee rate applied to the traded value on both buys and
                sells (e.g. 0.0015 == 0.15%).
        """
        self.warn = warn
        assert cash >= 10000, "Minimum cash requirement: 10,000"
        self.starting_capital = cash
        self.cash = cash
        self.positions = {}
        self.fees = fees
        self.max_date = None
        self.trades = []
        self.pos_history = {}
        self.single = single
        self.events = defaultdict(list)
        self.total_fees_paid = 0.0
        self.allow_fractional = allow_fractional

    def _normalize_date(self, date_obj):
        """Return date/datetime objects unchanged; parse anything else.

        Non-date input is converted with ``pd.to_datetime``.
        """
        if isinstance(date_obj, (date, datetime)):
            return date_obj
        return pd.to_datetime(date_obj)

    def _validate_and_update_date(self, date_obj):
        """Normalize ``date_obj``, enforce chronological order, track max date.

        Returns:
            The normalized date.

        Raises:
            ValueError: ``date_obj`` is earlier than the latest date seen.
        """
        date_obj = self._normalize_date(date_obj)
        assert isinstance(date_obj, (date, datetime)), (
            "date_obj must be date or datetime"
        )

        if self.max_date and date_obj < self.max_date:
            print(f"! Recived date: {date_obj}, self.max_date: {self.max_date}")
            raise ValueError("Date not in chronological order")

        if not self.max_date or date_obj > self.max_date:
            self.max_date = date_obj

        return date_obj

    def buy(self, ticker, date_obj, price, fixed_val=None):
        """Open a new position in ``ticker``.

        Args:
            ticker: Symbol to buy.
            date_obj: Trade date (see ``_normalize_date``); must not be
                earlier than any date already seen.
            price: Per-share price, converted to ``float``; must be positive.
            fixed_val: Cash amount to invest. Required unless the portfolio
                is in single-ticker mode; when ``None``, ``DEFAULT_PERCENT``
                of the available cash is invested.

        Returns:
            None. Does nothing (printing a note when ``warn``) if the ticker
            is already held, or if the affordable quantity is too small; the
            latter case is also recorded in ``events``.

        Raises:
            AssertionError: ``fixed_val`` missing in multi-ticker mode,
                non-positive price, ``fixed_val`` not above ``price``, or
                ``fixed_val`` below ``MIN_AMOUNT``.
            ValueError: Date out of order, insufficient cash for
                ``fixed_val``, or the computed cost exceeds available cash.
            Exception: Single-ticker mode asked to add a second ticker.
        """
        if fixed_val is None and not self.single:
            raise AssertionError(
                "For multi ticker portfolio fixed_val buy value is expected"
            )

        date_obj = self._validate_and_update_date(date_obj)
        price = float(price)
        # Same check as update(): a zero price would divide by zero below and
        # a negative one would silently yield a negative share quantity.
        assert price > 0, "Price must be positive"

        if self.single and self.positions and ticker not in self.positions:
            raise Exception(f"Single mode - can't add new ticker: {ticker}")

        if ticker in self.positions:
            if self.warn:
                print(f"! Re-buying not allowed: {ticker} {date_obj}")
            return

        # Compare against None (not truthiness) so that fixed_val=0 is
        # rejected by the checks below instead of silently investing all cash.
        if fixed_val is not None:
            assert price < fixed_val
            assert fixed_val >= self.MIN_AMOUNT, (
                f"fixed_val amount must be >= {self.MIN_AMOUNT}"
            )
            if self.cash < fixed_val:
                msg = f"Insufficient cash: {self.cash:.1f} < {fixed_val}"
                self.events[date_obj].append(msg)
                raise ValueError(msg)
            available_investment = fixed_val
        else:
            available_investment = self.cash * self.DEFAULT_PERCENT

        # Shave a tiny margin so float rounding never makes cost exceed cash.
        buffer_factor = 0.99999
        max_affordable_shares = (available_investment * buffer_factor) / (
            price * (1 + self.fees)
        )

        quantity = (
            max_affordable_shares
            if self.allow_fractional
            else int(max_affordable_shares)
        )

        if quantity <= 0.001:
            msg = f"""! buy quantity <= 0.001 for {ticker} at {price}: need {
                price * (1 + self.fees)},
             have {available_investment:.2f}"""
            self.events[date_obj].append(msg)
            if self.warn:
                print(msg)
            return

        share_cost = price * quantity
        fees = share_cost * self.fees
        total_cost = share_cost + fees

        # Invariant check, enforced regardless of `warn`: cash must never be
        # driven negative (e.g. by a subclass with DEFAULT_PERCENT above 1).
        if self.cash < total_cost:
            raise ValueError(
                f"Calculation error: cash {self.cash:.5f} < needed {total_cost:.5f}"
            )

        self.cash -= total_cost
        self.total_fees_paid += fees
        self.positions[ticker] = [date_obj, quantity, price, date_obj, price]

        self.events[date_obj].append(
            f"buy: {ticker}({price:.2f}*{quantity:.1f}, cost: {total_cost:.1f}, fees: {fees:.1f})"
        )

        if self.warn:
            print(
                f"Bought {quantity:.4f} shares of {ticker} at {price:.2f}, "
                f"total cost: {total_cost:.2f}"
            )

        self.save_positions(date_obj)

    def sell(self, ticker, date_obj, price, log_msg=""):
        """Close the whole position in ``ticker`` at ``price``.

        Records a ``TradeData`` in ``trades``, credits the proceeds net of
        fees to cash and snapshots the portfolio. Does nothing (printing a
        note when ``warn``) if ``ticker`` is not held.

        Raises:
            ValueError: ``date_obj`` is earlier than a date already seen.
        """
        date_obj = self._validate_and_update_date(date_obj)
        price = float(price)

        if ticker not in self.positions:
            if self.warn:
                print(f"Ticker not in portfolio: {ticker}")
            return

        buy_date, quantity, buy_price, last_date, last_price = self.positions.pop(
            ticker
        )

        gross_proceeds = price * quantity
        fees = gross_proceeds * self.fees
        net_proceeds = gross_proceeds - fees

        self.trades.append(
            TradeData(ticker, buy_date, buy_price, date_obj, price, quantity, log_msg)
        )

        self.cash += net_proceeds
        self.total_fees_paid += fees

        self.events[date_obj].append(
            f"sell: {ticker}({price:.2f}*{quantity:.1f}, proceeds: {net_proceeds:.1f}, fees: {fees:.1f})"
        )

        if self.warn:
            print(
                f"Sold {quantity:.4f} shares of {ticker} at {price:.2f}, "
                f"net proceeds: {net_proceeds:.2f}"
            )

        self.save_positions(date_obj)

    def update(self, ticker, date_obj, price):
        """Record the latest date/price of a held position and snapshot state.

        A snapshot is written even when ``ticker`` is not held.
        """
        date_obj = self._validate_and_update_date(date_obj)
        price = float(price)
        assert price > 0, "Price must be positive"

        if ticker in self.positions:
            self.positions[ticker][-2] = date_obj
            self.positions[ticker][-1] = price

        self.save_positions(date_obj)

    def process(self, signal, ticker, date_obj, price, buy_fixed=None, log_msg=""):
        """Unified entry-point for any daily signal.

        ``signal`` is ``None`` (price update), ``"buy"`` or ``"sell"``;
        anything else raises ``ValueError``.
        """
        if signal is None:
            self.update(ticker, date_obj, price)
        elif signal == "buy":
            self.buy(ticker, date_obj, price, fixed_val=buy_fixed)
        elif signal == "sell":
            self.sell(ticker, date_obj, price, log_msg=log_msg)
        else:
            raise ValueError(f"Unknown signal {signal!r}")

    def save_positions(self, record_date):
        """Store ``(cash, [(ticker, qty, last_price), ...])`` under ``record_date``."""
        pos_list = [
            (ticker, qty, last_price)
            for ticker, (_, qty, _, _, last_price) in self.positions.items()
        ]
        self.pos_history[record_date] = (self.cash, pos_list)

    def get_open_positions(self):
        """Summarize open positions valued at their last known price.

        Returns:
            ``{"positions_total": market value, "open_trades": [...]}`` where
            each open trade is ``trade_result(..., opened=True)`` built with
            the last date/price as the "sell" side, sorted by buy date.
        """
        if not self.positions:
            return {"positions_total": 0, "open_trades": []}

        open_trades = []
        total = 0

        for ticker, (
            buy_date,
            quantity,
            buy_price,
            last_date,
            last_price,
        ) in self.positions.items():
            if not last_date:
                last_date = self.max_date or datetime.now()
            if not last_price:
                last_price = buy_price

            trade_data = TradeData(
                ticker=ticker,
                buy_date=buy_date,
                buy_price=buy_price,
                sell_date=last_date,
                sell_price=last_price,
                quantity=quantity,
                log_msg="open",
            )
            open_trades.append(trade_result(trade_data, self.fees, opened=True))
            total += last_price * quantity

        open_trades.sort(key=lambda x: x.buy_date)
        return {"positions_total": total, "open_trades": open_trades}

    def current_value(self):
        """Return cash plus the market value of open positions."""
        total = self.cash
        if self.positions:
            positions_value = self.get_open_positions()["positions_total"]
            total += positions_value
            if self.warn:
                print(
                    f"Cash: {self.cash:.2f}, Positions: {positions_value:.2f}, "
                    f"Total: {total:.2f}"
                )
        return total

    def get_tradelist(self):
        """Return a DataFrame of closed and open trades sorted by sell date.

        Open trades use their last known date/price as the sell side.
        """
        closed_trades = [trade_result(trade, self.fees) for trade in self.trades]
        open_trades = self.get_open_positions()["open_trades"]

        all_trades = closed_trades + open_trades
        df = pd.DataFrame(all_trades, columns=TradeResult._fields)
        df = df.sort_values("sell_date").reset_index(drop=True)
        return df

    def has_position(self, ticker):
        """Check if ticker is in current positions."""
        return ticker in self.positions

    def __repr__(self):
        """Print a portfolio summary and return ``"Last Date: <max_date>"``.

        NOTE: the summary is written to stdout as a side effect of calling
        repr(); only the last-date line is part of the returned string.
        """
        print("---- PORTFOLIO ----")
        print(f"Total Value: {self.current_value():.2f}")
        print(f"Cash: {self.cash:.2f}")
        print(f"Positions Value: {self.get_open_positions()['positions_total']:.2f}")
        print(f"Total fees Paid: {self.total_fees_paid:.2f}")
        print(
            f"Return: {((self.current_value() / self.starting_capital - 1) * 100):.2f}%"
        )
        pprint(self.positions)
        return f"Last Date: {self.max_date}"