Commit graph

12 commits

Author SHA1 Message Date
Ayo  07672aaf2b  refactor: remove unneeded torch.no_grad()  2025-09-03 17:06:11 +02:00
Ayo  1cb2a50ce5  feat: load to XPU  2025-09-03 09:02:23 +02:00
Ayo  aeb26374df  feat: use same config as cuda for xpu  2025-09-02 23:26:22 +02:00
Ayo  66b32b2e4e  feat: use xpu as default if it exists  2025-09-02 23:13:54 +02:00
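The four commits above add Intel XPU support and make it the default device when present. A minimal sketch of that selection order, assuming a PyTorch-style availability check feeds the flags (the helper name `pick_device` and the exact CUDA/MPS ordering are assumptions, not taken from the repository):

```python
def pick_device(xpu_ok: bool, cuda_ok: bool, mps_ok: bool) -> str:
    """Return the preferred device string.

    Mirrors the priority the commits imply: XPU first if it exists,
    then CUDA, then MPS, then CPU. In real code the flags would come
    from torch.xpu.is_available(), torch.cuda.is_available(), and
    torch.backends.mps.is_available().
    """
    if xpu_ok:
        return "xpu"
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"
```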
gregory-fanous  6dd49cb720  feat: add Apple MPS device support (dtype + attention handling) to demos (#67)  2025-09-01 17:16:01 +08:00
    * Enhance model loading with device support and error handling: updated device handling for model loading, added support for MPS, and improved error handling and fallback mechanisms for attention implementations.
    * Improve device handling and model loading logic: updated the device argument to support MPS, added validation for MPS availability, and enhanced model loading based on the selected device type.
    * Fall back only when flash_attention_2 is used, and restore some comments.
    Co-authored-by: YaoyaoChang <cyy574006791@qq.com>
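PR #67 ties the dtype and attention implementation to the selected device. A rough sketch of that kind of per-device configuration table; the helper name `loading_config` and the specific dtype/backend choices are illustrative assumptions, not the PR's actual values:

```python
def loading_config(device: str) -> dict:
    """Hypothetical helper: pick a dtype and attention implementation
    per device, in the spirit of PR #67's dtype + attention handling."""
    if device == "mps":
        # MPS has no flash-attention kernels, so request SDPA instead.
        return {"torch_dtype": "float16", "attn_implementation": "sdpa"}
    if device in ("cuda", "xpu"):
        return {"torch_dtype": "bfloat16",
                "attn_implementation": "flash_attention_2"}
    # CPU fallback: full precision, SDPA.
    return {"torch_dtype": "float32", "attn_implementation": "sdpa"}
```

The returned dict could then be splatted into a `from_pretrained(...)`-style loader call.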
YaoyaoChang  3074f898ca  support CPU  2025-09-01 16:32:50 +08:00
YaoyaoChang  4b8b6f7700  update  2025-08-28 01:28:27 -07:00
_  e611deafac  FIX: adjust quote type in inference_from_file (#33)  2025-08-28 15:24:42 +08:00
    Adds full_script = full_script.replace("’", "'") to the data preparation.
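Commit e611deafac normalizes the Unicode right single quote to an ASCII apostrophe during data preparation. Wrapped as a standalone function (the name `normalize_quotes` is hypothetical; the replacement itself is the one shown in the commit body):

```python
def normalize_quotes(full_script: str) -> str:
    """Replace the typographic right single quote (U+2019) with an
    ASCII apostrophe, as commit e611deafac does before inference."""
    return full_script.replace("\u2019", "'")
```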
YaoyaoChang  ac0104c65e  try sdpa if error in flash_attention_2  2025-08-27 18:46:15 -07:00
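Commit ac0104c65e retries with SDPA when loading with flash_attention_2 fails. A minimal sketch of that try/except fallback; `load_fn` stands in for whatever loader the repo uses (e.g. a `from_pretrained(..., attn_implementation=...)` call), and the function name is an assumption:

```python
def load_with_fallback(load_fn,
                       preferred: str = "flash_attention_2",
                       fallback: str = "sdpa"):
    """Try the preferred attention implementation first; on any
    loading error, retry once with the fallback implementation."""
    try:
        return load_fn(attn_implementation=preferred)
    except Exception:
        return load_fn(attn_implementation=fallback)
```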
YaoyaoChang  8ed05ba7de  update to sdpa  2025-08-27 18:21:30 -07:00
YaoyaoChang  056bb5b0fa  add args to use_eager  2025-08-26 19:44:34 -07:00
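Commit 056bb5b0fa exposes the attention backend as a command-line argument. One plausible shape for such a flag, using stdlib argparse (the flag name, choices, and default here are assumptions, not the commit's actual interface):

```python
import argparse

# Hypothetical CLI flag in the spirit of "add args to use_eager":
# let the user force a specific attention implementation at launch.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--attn_implementation",
    choices=["flash_attention_2", "sdpa", "eager"],
    default="flash_attention_2",
    help="Attention backend to request when loading the model.",
)
args = parser.parse_args(["--attn_implementation", "eager"])
```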
pengzhiliang  d5895e76d4  init  2025-08-25 15:28:13 +00:00