# -*- coding: utf-8 -*-

"""
torch.compile Tutorial
======================
**Author:** William Wen
"""

######################################################################
# ``torch.compile`` is the latest method to speed up your PyTorch code!
# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling PyTorch code into optimized kernels,
# all while requiring minimal code changes.
# 
# In this tutorial, we cover basic ``torch.compile`` usage,
# and demonstrate the advantages of ``torch.compile`` over
# previous PyTorch compiler solutions, such as
# `TorchScript <https://pytorch.org/docs/stable/jit.html>`__ and 
# `FX Tracing <https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace>`__.
#
# **Contents**
# 
# - Basic Usage
# - Demonstrating Speedups
# - Comparison to TorchScript and FX Tracing
# - TorchDynamo and FX Graphs
# - Conclusion
#
# **Required pip Dependencies**
#
# - ``torch >= 2.0``
# - ``torchvision``
# - ``numpy``
# - ``scipy``
# - ``tabulate``

######################################################################
# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
# order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

# Only the GPUs named in the warning below count as "ok": compute capability
# 7.0 (V100), 8.0 (A100) and 9.0 (H100). Without CUDA, `gpu_ok` stays False.
gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    gpu_ok = device_cap in ((7, 0), (8, 0), (9, 0))

# Warn rather than abort: the tutorial still runs, only the numbers differ.
if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )

######################################################################
# Basic Usage
# ------------
#
# ``torch.compile`` is included in the latest PyTorch.
# Running TorchInductor on GPU requires Triton, which is included with the PyTorch 2.0 nightly
# binary. If Triton is still missing, try installing ``torchtriton`` via pip 
# (``pip install torchtriton --extra-index-url "https://download.pytorch.org/whl/nightly/cu117"``
# for CUDA 11.7).
#
# Arbitrary Python functions can be optimized by passing the callable to
# ``torch.compile``. We can then call the returned optimized
# function in place of the original function.

def foo(x, y):
    """Return the element-wise sum of ``sin(x)`` and ``cos(y)``."""
    return torch.sin(x) + torch.cos(y)
# `torch.compile` returns a callable that is invoked exactly like `foo`;
# here it is called once on two random 10x10 tensors.
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))

######################################################################
# Alternatively, we can decorate the function.

@torch.compile
def opt_foo2(x, y):
    """Return the element-wise sum of ``sin(x)`` and ``cos(y)``."""
    return torch.sin(x) + torch.cos(y)
print(opt_foo2(torch.randn(10, 10), torch.randn(10, 10)))

######################################################################
# We can also optimize ``torch.nn.Module`` instances.

class MyModule(torch.nn.Module):
    """Toy module: a 100-to-10 linear layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        """Apply the linear layer to ``x`` and clamp negatives to zero."""
        projected = self.lin(x)
        return torch.nn.functional.relu(projected)

# Compile a module instance and call it on a random batch of shape (10, 100),
# matching the 100 input features of `MyModule.lin`.
mod = MyModule()
opt_mod = torch.compile(mod)
print(opt_mod(torch.randn(10, 100)))

######################################################################
# Demonstrating Speedups
# -----------------------
#
# Let's now demonstrate that using ``torch.compile`` can speed
# up real models. We will compare standard eager mode and 
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    """Run ``fn()`` once and measure how long it took.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        Tuple ``(result, seconds)`` where ``result`` is whatever ``fn()``
        returned and ``seconds`` is the elapsed time as a float.
    """
    if not torch.cuda.is_available():
        # CUDA events cannot be created without a CUDA device; fall back to a
        # wall-clock timer so the helper still works on CPU-only machines.
        import time

        begin = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - begin

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before reading the timing.
    torch.cuda.synchronize()
    # `elapsed_time` reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    """Create one random batch of inputs and integer targets.

    Args:
        b: Batch size.

    Returns:
        Tuple ``(inputs, targets)``: ``inputs`` is a float32 tensor of shape
        ``(b, 3, 128, 128)`` and ``targets`` holds ``b`` integers drawn from
        ``[0, 1000)``. Both live on the GPU when CUDA is available and on the
        CPU otherwise.
    """
    # Fall back to the CPU instead of crashing when no CUDA device is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).to(device),
        torch.randint(1000, (b,)).to(device),
    )

# Number of timed iterations used by each benchmark loop below.
N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    """Build a fresh float32 ``densenet121``.

    Returns:
        A newly constructed model, placed on the GPU when CUDA is available
        and on the CPU otherwise.
    """
    # Fall back to the CPU instead of crashing when no CUDA device is present.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return densenet121().to(torch.float32).to(device)

######################################################################
# First, let's compare inference.
#
# Note that in the call to ``torch.compile``, we have the additional
# ``mode`` argument, which we will discuss below.

def evaluate(mod, inp):
    """Call ``mod`` on ``inp`` with gradient tracking disabled and return the output."""
    with torch.no_grad():
        output = mod(inp)
    return output

model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

evaluate_opt = torch.compile(evaluate, mode="reduce-overhead")

# `generate_data` returns (inputs, targets); inference only needs the inputs.
inp = generate_data(16)[0]
# Time one call of each variant. `timed` returns (result, seconds), so `[1]`
# selects the elapsed seconds.
print("eager:", timed(lambda: evaluate(model, inp))[1])
print("compile:", timed(lambda: evaluate_opt(model, inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete
# compared to eager. This is because ``torch.compile`` compiles
# the model into optimized kernels as it executes. In our example, the
# structure of the model doesn't change, and so recompilation is not
# needed. So if we run our optimized model several more times, we should
# see a significant improvement compared to eager.

# Time N_ITERS eager-mode evaluations, each on a freshly generated batch of 16.
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, eager_time = timed(lambda: evaluate(model, inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

# Repeat the same measurement with the compiled function.
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    _, compile_time = timed(lambda: evaluate_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

# Compare the medians of the two runs; a ratio above 1 means the compiled
# runs were faster.
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
# GPU reads/writes, and so the observed speedup may vary depending on factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
# GPU compute and the observed speedup may be less significant.
#
# You may also see different speedup results depending on the chosen ``mode``
# argument. Since our model and data are small, we want to reduce overhead as
# much as possible, and so we chose ``"reduce-overhead"``. For your own models,
# you may need to experiment with different modes to maximize speedup. You can
# read more about modes `here <https://pytorch.org/get-started/pytorch-2.0/#user-experience>`__.
#
# For general PyTorch benchmarking, you can try using ``torch.utils.benchmark`` instead of the ``timed``
# function we defined above. We wrote our own timing function in this tutorial to show
# ``torch.compile``'s compilation latency.
#
# Now, let's consider comparing training.

model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    """Run one optimization step of ``mod`` on ``data``.

    ``data`` is an ``(inputs, targets)`` pair. The step uses the module-level
    optimizer ``opt`` and a cross-entropy loss; nothing is returned.
    """
    inputs, targets = data[0], data[1]
    opt.zero_grad(True)
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(mod(inputs), targets)
    loss.backward()
    opt.step()

# Time N_ITERS eager-mode training steps, each on a freshly generated batch of 16.
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

# Start over with a new model and optimizer for the compiled run. `train`
# reads the module-level `opt`, so rebinding it here changes what `train` uses.
model = init_model()
opt = torch.optim.Adam(model.parameters())
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

# Compare the medians of the two runs; a ratio above 1 means the compiled
# runs were faster.
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

######################################################################
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.

######################################################################
# Comparison to TorchScript and FX Tracing
# -----------------------------------------
# 
# We have seen that ``torch.compile`` can speed up PyTorch code.
# Why else should we use ``torch.compile`` over existing PyTorch
# compiler solutions, such as TorchScript or FX Tracing? Primarily, the
# advantage of ``torch.compile`` lies in its ability to handle
# arbitrary Python code with minimal changes to existing code.
#
# One case that ``torch.compile`` can handle that other compiler
# solutions struggle with is data-dependent control flow (the 
# ``if x.sum() < 0:`` line below).

def f1(x, y):
    """Return ``-y`` when the elements of ``x`` sum to a negative value, else ``y``."""
    # The branch taken depends on the *values* in `x`, not just its type or shape.
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given
# the same arguments `args`. Typically, `fn1` will be an eager function
# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).
def test_fns(fn1, fn2, args):
    """Call both functions with ``*args`` and report whether the outputs are close.

    Returns the boolean result of ``torch.allclose`` on the two outputs.
    """
    reference, candidate = fn1(*args), fn2(*args)
    return torch.allclose(reference, candidate)

# Two random 5x5 inputs shared by the comparisons that follow.
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

######################################################################
# TorchScript tracing ``f1`` results in
# silently incorrect results, since only the actual control flow path
# is traced.

# Trace `f1` using `inp1` as the example value for `x`.
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
# Same `x` as the one used for tracing.
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
# Negating `x` flips the sign of `x.sum()`, so eager `f1` takes the other branch.
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

######################################################################
# FX tracing ``f1`` results in an error due to the presence of
# data-dependent control flow.

import traceback as tb
try:
    torch.fx.symbolic_trace(f1)
except Exception:
    # The failure is the point of this example: print the traceback and keep
    # going. `Exception` (rather than a bare `except`) lets KeyboardInterrupt
    # and SystemExit still stop the script.
    tb.print_exc()

######################################################################
# If we provide a value for ``x`` as we try to FX trace ``f1``, then
# we run into the same problem as TorchScript tracing, as the data-dependent
# control flow is removed in the traced function.

# Supply a concrete value for `x` so that FX tracing can proceed.
fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={"x": inp1})
# Same `x` as the concrete value used for tracing.
print("fx 1, 1:", test_fns(f1, fx_f1, (inp1, inp2)))
# Negating `x` flips the sign of `x.sum()`, so eager `f1` takes the other branch.
print("fx 1, 2:", test_fns(f1, fx_f1, (-inp1, inp2)))

######################################################################
# Now we can see that ``torch.compile`` correctly handles
# data-dependent control flow.

# Reset since we are using a different mode.
torch._dynamo.reset()

compile_f1 = torch.compile(f1)
# Compare eager and compiled `f1` on inputs that exercise both branches.
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)

######################################################################
# TorchScript scripting can handle data-dependent control flow, but this
# solution comes with its own set of problems. Namely, TorchScript scripting
# can require major code changes and will raise errors when unsupported Python
# is used.
#
# In the example below, we forget TorchScript type annotations and we receive
# a TorchScript error because the input type for argument ``y``, an ``int``,
# does not match with the default argument type, ``torch.Tensor``.

# Adds its two arguments. It carries no type annotations on purpose, which is
# what the scripting example below relies on.
def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    # Call the scripted function with a tensor and a plain Python int.
    script_f2(inp1, inp2)
except Exception:
    # The failure is the point of this example: print the traceback and keep
    # going. `Exception` (rather than a bare `except`) lets KeyboardInterrupt
    # and SystemExit still stop the script.
    tb.print_exc()

######################################################################
# However, ``torch.compile`` is easily able to handle ``f2``.

# Compile the unannotated `f2` and compare it with eager on the same
# (tensor, int) arguments used above.
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)

######################################################################
# Another case that ``torch.compile`` handles well compared to
# previous compiler solutions is the usage of non-PyTorch functions.

# Import the submodule explicitly: a bare `import scipy` is not guaranteed to
# make `scipy.fft` available on every SciPy version.
import scipy.fft
def f3(x):
    """Double ``x``, apply SciPy's DCT to it, and double the result.

    The tensor is converted to a NumPy array for the SciPy call and back to a
    tensor afterwards, so part of the computation happens outside PyTorch.
    """
    x = x * 2
    x = scipy.fft.dct(x.numpy())
    x = torch.from_numpy(x)
    x = x * 2
    return x

######################################################################
# TorchScript tracing treats results from non-PyTorch function calls
# as constants, and so our results can be silently wrong.

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
# Trace with `inp1`, then compare against eager on a different input, `inp2`.
traced_f3 = torch.jit.trace(f3, (inp1,))
print("traced 3:", test_fns(f3, traced_f3, (inp2,)))

######################################################################
# TorchScript scripting and FX tracing disallow non-PyTorch function calls.

# Both attempts are expected to fail; print each traceback and keep going.
# `Exception` (rather than a bare `except`) lets KeyboardInterrupt and
# SystemExit still stop the script.
try:
    torch.jit.script(f3)
except Exception:
    tb.print_exc()

try:
    torch.fx.symbolic_trace(f3)
except Exception:
    tb.print_exc()

######################################################################
# In comparison, ``torch.compile`` is easily able to handle
# the non-PyTorch function call.

# Compile `f3` and compare it with eager on `inp2`.
compile_f3 = torch.compile(f3)
print("compile 3:", test_fns(f3, compile_f3, (inp2,)))

######################################################################
# TorchDynamo and FX Graphs
# --------------------------
#
# One important component of ``torch.compile`` is TorchDynamo.
# TorchDynamo is responsible for JIT compiling arbitrary Python code into
# `FX graphs <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__, which can
# then be further optimized. TorchDynamo extracts FX graphs by analyzing Python bytecode
# during runtime and detecting calls to PyTorch operations.
# 
# Normally, TorchInductor, another component of ``torch.compile``,
# further compiles the FX graphs into optimized kernels,
# but TorchDynamo allows for different backends to be used. In order to inspect
# the FX graphs that TorchDynamo outputs, let us create a custom backend that
# outputs the FX graph and simply returns the graph's unoptimized forward method.

from typing import List
def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    """Backend that prints the FX graph it receives and performs no optimization.

    Args:
        gm: Graph module handed to the backend.
        example_inputs: Example inputs for the graph; not used here.

    Returns:
        ``gm.forward``, unchanged.
    """
    print("custom backend called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward

# Reset since we are using a different backend.
torch._dynamo.reset()

# Compile a fresh model with the printing backend and run it once on the
# inputs of a random batch of 16.
opt_model = torch.compile(init_model(), backend=custom_backend)
opt_model(generate_data(16)[0])

######################################################################
# Using our custom backend, we can now see how TorchDynamo is able to handle
# data-dependent control flow. Consider the function below, where the line
# ``if b.sum() < 0`` is the source of data-dependent control flow.

def bar(a, b):
    """Return ``a / (|a| + 1)`` multiplied by ``b``, negating ``b`` first when its sum is negative."""
    x = a / (torch.abs(a) + 1)
    # The branch taken depends on the values in `b`.
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar, backend=custom_backend)
inp1 = torch.randn(10)
inp2 = torch.randn(10)
# Call with `inp2` and with `-inp2`: negation flips the sign of `b.sum()`,
# so the two calls take different branches.
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)

######################################################################
# The output reveals that TorchDynamo extracted 3 different FX graphs
# corresponding to the following code (order may differ from the output above):
#
# 1. ``x = a / (torch.abs(a) + 1)``
# 2. ``b = b * -1; return x * b``
# 3. ``return x * b``
#
# When TorchDynamo encounters unsupported Python features, such as data-dependent
# control flow, it breaks the computation graph, lets the default Python
# interpreter handle the unsupported code, then resumes capturing the graph.
#
# Let's investigate by example how TorchDynamo would step through ``bar``.
# If ``b.sum() < 0``, then TorchDynamo would run graph 1, let
# Python determine the result of the conditional, then run
# graph 2. On the other hand, if ``not b.sum() < 0``, then TorchDynamo
# would run graph 1, let Python determine the result of the conditional, then
# run graph 3.
#
# This highlights a major difference between TorchDynamo and previous PyTorch
# compiler solutions. When encountering unsupported Python features,
# previous solutions either raise an error or silently fail.
# TorchDynamo, on the other hand, will break the computation graph.
#
# We can see where TorchDynamo breaks the graph by using ``torch._dynamo.explain``:

# Reset since we are using a different backend.
torch._dynamo.reset()
# NOTE(review): unpacking six values relies on the shape of the tuple returned
# by `torch._dynamo.explain` when called as `explain(fn, *args)` -- confirm
# against the installed PyTorch version.
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = torch._dynamo.explain(
    bar, torch.randn(10), torch.randn(10)
)
# Only the verbose explanation is printed; the other values are not used here.
print(explanation_verbose)

######################################################################
# In order to maximize speedup, graph breaks should be limited.
# We can force TorchDynamo to raise an error upon the first graph
# break encountered by using ``fullgraph=True``:

opt_bar = torch.compile(bar, fullgraph=True)
try:
    opt_bar(torch.randn(10), torch.randn(10))
except Exception:
    # The failure is the point of this example: print the traceback and keep
    # going. `Exception` (rather than a bare `except`) lets KeyboardInterrupt
    # and SystemExit still stop the script.
    tb.print_exc()

######################################################################
# And below, we demonstrate that TorchDynamo does not break the graph on
# the model we used above for demonstrating speedups.

# Compile a fresh model with `fullgraph=True` and run it on the inputs of a
# random batch of 16.
opt_model = torch.compile(init_model(), fullgraph=True)
print(opt_model(generate_data(16)[0]))

######################################################################
# Finally, if we simply want TorchDynamo to output the FX graph for export,
# we can use ``torch._dynamo.export``. Note that ``torch._dynamo.export``, like
# ``fullgraph=True``, raises an error if TorchDynamo breaks the graph.

try:
    torch._dynamo.export(bar, torch.randn(10), torch.randn(10))
except Exception:
    # The failure is the point of this example: print the traceback and keep
    # going. `Exception` (rather than a bare `except`) lets KeyboardInterrupt
    # and SystemExit still stop the script.
    tb.print_exc()

# Export a fresh model using one random batch as example input, then call the
# first element of the export result on another random batch.
model_exp = torch._dynamo.export(init_model(), generate_data(16)[0])
print(model_exp[0](generate_data(16)[0]))

######################################################################
# Conclusion
# ------------
#
# In this tutorial, we introduced ``torch.compile`` by covering
# basic usage, demonstrating speedups over eager mode, comparing to previous
# PyTorch compiler solutions, and briefly investigating TorchDynamo and its interactions
# with FX graphs. We hope that you will give ``torch.compile`` a try!
