# PostRegression Workflow
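This notebook demonstrates the end-to-end EnerGaze workflow for the UMP 4.4.8 PostRegression model, including exploratory data analysis (EDA), diagnostics, and model selection hooks. If the demo CSV is not available, a synthetic monthly site-level panel is generated in its place so the notebook can still be run.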

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from energaze.models import PostRegression, TWFE
from energaze.visualization import build_load_shape_chart, build_treatment_balance_chart
from energaze.reporting.formatter import ReportFormatter
from energaze.orchestration.registry import ModelRegistry

In [2]:
DATA_PATH = Path('examples/data/post_regression_demo.csv')
TREATMENT_START = pd.Timestamp('2023-01-01')  # first date flagged as post-treatment
SEED = 42  # fixed so the synthetic fallback is identical on every Restart & Run All


def make_demo_panel(n_sites=80, seed=SEED):
    """Build a synthetic monthly panel of site-level consumption.

    Each site gets a random treatment flag, state, wave and baseline level.
    Every month-end from 2022-01 through 2023-12 then gets one row whose
    consumption is the baseline plus noise, lowered by 80 for treated sites
    on or after ``TREATMENT_START``.

    Parameters
    ----------
    n_sites : int
        Number of sites to simulate.
    seed : int
        Seed for the random generator; the same seed yields the same frame.

    Returns
    -------
    pandas.DataFrame
        One row per (site, month) with columns ``date``, ``site_id``,
        ``consumption``, ``treated``, ``state``, ``wave``, ``post_treatment``.
    """
    # A local generator keeps the draws reproducible and independent of any
    # global numpy random state touched elsewhere in the session.
    rng = np.random.default_rng(seed)
    # An offset object is used instead of the 'M' string alias, which recent
    # pandas releases deprecate in favour of 'ME'.
    dates = pd.date_range('2022-01-01', '2023-12-31', freq=pd.offsets.MonthEnd())
    states = ['CA', 'TX', 'NY']
    waves = ['pilot', 'scale']

    records = []
    for site_id in range(n_sites):
        # Site-level attributes are drawn once and repeated on every row.
        treated = int(rng.binomial(1, 0.5))
        state = str(rng.choice(states))
        wave = str(rng.choice(waves))
        baseline = rng.normal(900, 120)
        for date in dates:
            post = int(date >= TREATMENT_START)
            # The effect applies only to treated sites in the post period.
            shift = -80 if treated and post else 0
            records.append({
                'date': date,
                'site_id': f'site_{site_id:03d}',
                'consumption': baseline + rng.normal(0, 60) + shift,
                'treated': treated,
                'state': state,
                'wave': wave,
                'post_treatment': post
            })
    return pd.DataFrame(records)


# Prefer the demo file when it is present; otherwise fall back to synthetic data.
if DATA_PATH.exists():
    df = pd.read_csv(DATA_PATH)
else:
    df = make_demo_panel()
# read_csv yields strings for dates, so normalise the column in both branches.
df['date'] = pd.to_datetime(df['date'])
df.head()

Unnamed: 0,date,site_id,consumption,treated,state,wave,post_treatment
0,2022-01-31,site_000,898.684496,0,NY,scale,0
1,2022-02-28,site_000,860.335713,0,NY,scale,0
2,2022-03-31,site_000,838.335201,0,NY,scale,0
3,2022-04-30,site_000,843.422817,0,NY,scale,0
4,2022-05-31,site_000,850.038984,0,NY,scale,0


## Exploratory Data Analysis

In [3]:
# Build the load-shape chart from the date and consumption columns, passing
# `state` as the color column.
load_fig = build_load_shape_chart(
    df,
    time_col='date',
    value_col='consumption',
    color_col='state',
)
load_fig  # bare last expression so the notebook renders the figure

In [4]:
# NOTE(review): the saved output of this cell is
# "DuplicateError: Expected unique column names, got: 'count' 2 times".
# The columns shown for `df` do not include 'count', so the duplicate is
# presumably produced inside build_treatment_balance_chart -- TODO confirm
# and fix upstream. This is consistent with the later cells being unexecuted
# (In [None]) in the saved notebook.
balance_fig = build_treatment_balance_chart(
    df,
    treatment_col='treated',
)
balance_fig  # bare last expression so the notebook renders the figure

DuplicateError: Expected unique column names, got:
- 'count' 2 times

## Fit PostRegression

In [None]:
# Column roles and the treatment start date handed to PostRegression.
post_spec = {
    'consumption_var': 'consumption',
    'treatment_var': 'treated',
    'time_var': 'date',
    'site_var': 'site_id',
    'state_var': 'state',
    'wave_var': 'wave',
    'treatment_date': '2023-01-01',
}

post_model = PostRegression(data=df, **post_spec)
post_model.fit()

# Keep the effect in a named variable and display it as the cell output.
post_result = post_model.get_treatment_effect()
post_result

## Model Selection Example

In [None]:
# Fit TWFE with the same column roles and treatment date used for PostRegression.
twfe_spec = {
    'consumption_var': 'consumption',
    'treatment_var': 'treated',
    'time_var': 'date',
    'site_var': 'site_id',
    'state_var': 'state',
    'wave_var': 'wave',
    'treatment_date': '2023-01-01',
}
twfe_model = TWFE(data=df, **twfe_spec)
twfe_model.fit()

# Register both fitted models under their display names; `post_model` comes
# from the PostRegression cell, so that cell must have been run first.
registry = ModelRegistry()
fitted_models = {'PostRegression': post_model, 'TWFE': twfe_model}
for model_name, fitted in fitted_models.items():
    registry.register(name=model_name, model=fitted)

# Convert the registry's comparison to a DataFrame and display it.
comparison = registry.build_comparison().to_dataframe()
comparison

## Reporting

In [None]:
# Format the PostRegression results; `post_model` must already be fitted by
# the earlier cell.
formatter = ReportFormatter()
post_effects = post_model.iter_results()
formatter.format_effects(post_effects, model_name='PostRegression')