# Base image: rocm/pytorch, pinned to an explicit tag. The tag name indicates
# ROCm 6.2, Ubuntu 22.04, Python 3.10 and a PyTorch 2.3.0 release build.
FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

# Set build architecture
# NOTE(review): PYTORCH_ROCM_ARCH is presumably read by the PyTorch build to
# choose which AMD GPU target(s) to compile for, and gfx90a presumably selects
# a single target family -- confirm against the PyTorch/ROCm build docs.
# Declared with ENV, so the value is visible to every later RUN step and also
# persists into the environment of containers started from this image.
ENV PYTORCH_ROCM_ARCH=gfx90a

# Clone PyTorch v2.4.1, build it, and remove the build tree -- all in ONE layer.
#
# Why a single RUN:
#   * Files deleted in a later RUN still occupy space in the earlier layers,
#     so cleaning up /tmp/build in its own instruction would not shrink the
#     image. Removing it in the same RUN that created it keeps the source
#     checkout and intermediate build artifacts out of the final image.
#   * `rm -rf` succeeds whether or not the path exists, so no existence test
#     is needed. (A bare `[ -e path ] && rm -rf path` exits non-zero -- and
#     fails the build -- when the path is absent.)
#   * The `&&` chain stops at the first failing command, so a non-zero exit
#     status from .ci/pytorch/build.sh fails this RUN and aborts the image
#     build; there is no need to echo the status.
#   * `cd` is used rather than WORKDIR so the image's working directory is
#     left unchanged and never points at the deleted build tree.
# hadolint ignore=DL3003
RUN rm -rf /tmp/build \
    && mkdir -p /tmp/build \
    && cd /tmp/build \
    && git clone https://github.com/pytorch/pytorch.git \
    && cd pytorch \
    && git checkout v2.4.1 \
    && git submodule sync \
    && git submodule update --init --recursive \
    && .ci/pytorch/build.sh \
    && cd / \
    && rm -rf /tmp/build
