본문 바로가기
Python, API

[Python] SuperTrend

by 오늘밤날다 2025. 5. 18.

 

인터넷에 돌아다니는 코드를 조금 바꿨다.

데이터 크기(행 수)에 따라서는 compute_supertrend 함수를 numba(@njit)로 JIT 컴파일해서 처리하는 것이 더 빠를 수 있다.

 

    def supertrend(df, atr_period=14, atr_multiplier=3.0, source='HL'):
        """Compute SuperTrend upper/lower bands and add them to ``df``.

        ``df`` is modified in place and also returned. It must have
        'high', 'low' and 'close' columns. Columns added:
          - 'tr':     true range (NaN on the first row, which has no previous close)
          - 'atr_st': simple rolling mean of 'tr' over ``atr_period`` rows
          - 'stub':   final (trailing) upper band
          - 'stlb':   final (trailing) lower band
        Rows where the ATR is not yet available (warm-up) get 0.0 in 'stub'/'stlb'.

        Args:
            df: price DataFrame with 'high', 'low', 'close' columns.
            atr_period: rolling window length for the ATR.
            atr_multiplier: band width in units of ATR.
            source: center price for the bands: 'HL' = (high+low)/2,
                'HLC' = (high+low+close)/3, 'C' = close.

        Returns:
            The same ``df`` object, with the columns above added.

        Raises:
            ValueError: if ``source`` is not one of 'HL', 'HLC', 'C'.
        """
        prev_close = df['close'].shift(1)
        # True range = max(high - low, |high - prev close|, |low - prev close|).
        tr_a = df['high'] - df['low']
        tr_b = abs(df['high'] - prev_close)
        tr_c = abs(df['low'] - prev_close)
        tr_temp = np.maximum(tr_a, tr_b)
        df['tr'] = np.maximum(tr_temp, tr_c)
        # Simple moving average of TR (not Wilder's smoothing).
        df['atr_st'] = df['tr'].rolling(atr_period).mean()

        if source == 'HL':
            target_price = (df['high'] + df['low']) / 2
        elif source == 'HLC':
            target_price = (df['high'] + df['low'] + df['close']) / 3
        elif source == 'C':
            target_price = df['close']
        else:
            # Without this branch an unknown source failed later with UnboundLocalError.
            raise ValueError(f"source must be 'HL', 'HLC' or 'C', got {source!r}")

        basic_upper = target_price + atr_multiplier * df['atr_st']
        basic_lower = target_price - atr_multiplier * df['atr_st']

        def compute_supertrend(close, basic_upper, basic_lower):
            """Turn basic bands into trailing final bands (plain NumPy loop, numba-friendly)."""
            n = len(close)
            # Zero-filled so warm-up rows (including index 0) are a defined 0.0,
            # not whatever uninitialized memory np.empty would return.
            upper = np.zeros(n)
            lower = np.zeros(n)
            prev_valid = False
            for i in range(n):
                if np.isnan(basic_upper[i]) or np.isnan(basic_lower[i]):
                    # Warm-up / missing data: keep 0.0 and restart the recurrence later.
                    prev_valid = False
                    continue
                if not prev_valid:
                    # First usable bar: seed from the basic bands instead of comparing
                    # against the 0.0 placeholder (which breaks when a band is <= 0).
                    upper[i] = basic_upper[i]
                    lower[i] = basic_lower[i]
                else:
                    # Upper band only moves down, unless the previous close broke above it.
                    if basic_upper[i] < upper[i - 1] or close[i - 1] > upper[i - 1]:
                        upper[i] = basic_upper[i]
                    else:
                        upper[i] = upper[i - 1]
                    # Lower band only moves up, unless the previous close broke below it.
                    if basic_lower[i] > lower[i - 1] or close[i - 1] < lower[i - 1]:
                        lower[i] = basic_lower[i]
                    else:
                        lower[i] = lower[i - 1]
                prev_valid = True
            return upper, lower

        df['stub'], df['stlb'] = compute_supertrend(df['close'].values, basic_upper.values, basic_lower.values)
        return df