Update all-in-one benchmark prompts for continuation task & lookup update for minicpmv (#11827)

* Update all-in-one benchmark prompts for continuation task

* Small fix

* Add pure-text benchmark support for minicpm-v-2_6

* Support lookahead for model.llm generate of minicpmv

* Add prompt reference

* Small update

* Small fix
This commit is contained in:
Yuwen Hu 2024-08-16 17:16:35 +08:00 committed by GitHub
parent e966e85df8
commit 96796f95cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 803 additions and 17 deletions

View file

@ -4,6 +4,8 @@ All in one benchmark test allows users to test all the benchmarks and record the
Before running, make sure to have [ipex-llm](../../../../../README.md) installed.
> The prompts for benchmarking are from datasets [abisee/cnn_dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail), [Open-Orca/OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca), [THUDM/LongBench](https://huggingface.co/datasets/THUDM/LongBench), etc.
## Dependencies
```bash

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,760 @@
Please complete the code given below.
app/src/main/java/com/matejdro/pebbledialer/callactions/EndCallAction.java
public class EndCallAction extends CallAction
{
public static final int END_CALL_ACTION_ID = 1;
private PendingIntent notificationEndCallIntent;
private static Method getITelephonyMethod;
public EndCallAction(CallModule callModule)
{
super(callModule);
try {
getITelephonyMethod = TelephonyManager.class.getDeclaredMethod("getITelephony", (Class[]) null);
getITelephonyMethod.setAccessible(true);
} catch (NoSuchMethodException e) {
Timber.e(e, "iTelephony end not supported on your phone!");
} catch (Exception e) {
Timber.e(e, "Error while acquiring iTelephony");
Crashlytics.logException(e);
}
}
public void registerNotificationEndCallIntent(PendingIntent notificationAnswerIntent)
{
this.notificationEndCallIntent = notificationAnswerIntent;
}
@Override
public void executeAction()
{
getCallModule().setCloseAutomaticallyAfterThisCall(true);
if (getCallModule().getService().getGlobalSettings().getBoolean("rootMode", false))
{
Timber.d("Ending call using root method...");
try {
Runtime.getRuntime().exec(new String[] {"su", "-c", "input keyevent 6"});
return;
} catch (IOException e) {
e.printStackTrace();
}
}
if (getCallModule().getCallState() == CallModule.CallState.RINGING && notificationEndCallIntent != null)
{
Timber.d("Ending call using notification method...");
try {
notificationEndCallIntent.send();
return;
} catch (PendingIntent.CanceledException e) {
}
}
if (getITelephonyMethod != null)
{
Timber.d("Ending call using generic iTelephony method...");
try
{
ITelephony iTelephony = (ITelephony) getITelephonyMethod.invoke(getCallModule().getService().getSystemService(Context.TELEPHONY_SERVICE), (Object[]) null);
iTelephony.endCall();
return;
}
catch (SecurityException e)
{
Timber.e("Cannot decline call, no CALL_PHONE permission.");
}
catch (Exception e) {
Timber.e(e, "Error while invoking iTelephony.endCall()");
Crashlytics.logException(e);
}
}
Timber.e("All end call options failed! Nothing is supported.");
}
@Override
public void onCallEnd()
{
notificationEndCallIntent = null; //Reset intent (there will be new intent at next call)
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_END_CALL;
}
public static EndCallAction get(CallModule callModule)
{
return (EndCallAction) callModule.getCallAction(END_CALL_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/AnswerCallAction.java
public class AnswerCallAction extends CallAction
{
public static final int ANSWER_ACTION_ID = 0;
private PendingIntent notificationAnswerIntent;
public AnswerCallAction(CallModule callModule)
{
super(callModule);
}
public void registerNotificationAnswerIntent(PendingIntent notificationAnswerIntent)
{
this.notificationAnswerIntent = notificationAnswerIntent;
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.RINGING)
return;
if (getCallModule().getService().getGlobalSettings().getBoolean("rootMode", false))
{
Timber.d("Answering using root method...");
try {
Runtime.getRuntime().exec(new String[] {"su", "-c", "input keyevent 5"});
return;
} catch (IOException e) {
e.printStackTrace();
}
}
if (notificationAnswerIntent != null)
{
Timber.d("Answering using notification method...");
try {
notificationAnswerIntent.send();
return;
} catch (PendingIntent.CanceledException e) {
}
}
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O)
{
answerNativelyOreo();
}
else if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP)
{
answerUsingMediaServer();
}
else
{
Timber.d("Answering using generic headset hook method...");
Intent buttonUp = new Intent(Intent.ACTION_MEDIA_BUTTON);
buttonUp.putExtra(Intent.EXTRA_KEY_EVENT, new KeyEvent(KeyEvent.ACTION_UP, KeyEvent.KEYCODE_HEADSETHOOK));
getCallModule().getService().sendOrderedBroadcast(buttonUp, "android.permission.CALL_PRIVILEGED");
}
}
@TargetApi(Build.VERSION_CODES.O)
private void answerNativelyOreo() {
TelecomManager telecomManager
= (TelecomManager) getCallModule().getService().getSystemService(Context.TELECOM_SERVICE);
Timber.d("Answering natively with Oreo.");
try {
telecomManager.acceptRingingCall();
} catch (SecurityException e) {
Timber.e("No accept call permission!");
}
}
@TargetApi(Build.VERSION_CODES.LOLLIPOP)
private void answerUsingMediaServer()
{
Timber.d("Answering using media server method...");
MediaSessionManager mediaSessionManager = (MediaSessionManager) getCallModule().getService().getSystemService(Context.MEDIA_SESSION_SERVICE);
try {
List<MediaController> mediaControllerList = mediaSessionManager.getActiveSessions
(new ComponentName(getCallModule().getService(), JellybeanNotificationListener.class));
for (MediaController m : mediaControllerList) {
if ("com.android.server.telecom".equals(m.getPackageName())) {
Timber.d("Found telephony media controller!");
m.dispatchMediaButtonEvent(new KeyEvent(KeyEvent.ACTION_UP, KeyEvent.KEYCODE_HEADSETHOOK));
break;
}
}
} catch (SecurityException e) {
Timber.e("Notification service not running!");
}
}
@Override
public void onCallEnd()
{
notificationAnswerIntent = null; //Reset intent (there will be new intent at next call)
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_ANSWER;
}
public static AnswerCallAction get(CallModule callModule)
{
return (AnswerCallAction) callModule.getCallAction(ANSWER_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/SMSReplyAction.java
public class SMSReplyAction extends CallAction
{
public static final int SMS_REPLY_ACTION_ID = 6;
public SMSReplyAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
ToggleRingerAction toggleRingerAction = ToggleRingerAction.get(getCallModule());
toggleRingerAction.mute();
SMSReplyModule smsReplyModule = SMSReplyModule.get(getCallModule().getService());
smsReplyModule.startSMSProcess(getCallModule().getNumber());
getCallModule().setCloseAutomaticallyAfterThisCall(false);
}
@Override
public void onCallEnd()
{
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_END_CALL;
}
public static SMSReplyAction get(CallModule callModule)
{
return (SMSReplyAction) callModule.getCallAction(SMS_REPLY_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/ToggleMicrophoneAction.java
public class ToggleMicrophoneAction extends CallAction
{
public static final int TOGGLE_MICROPHONE_ACTION_ID = 3;
private boolean microphoneMuted = false;
public ToggleMicrophoneAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.ESTABLISHED)
return;
microphoneMuted = !microphoneMuted;
if (getCallModule().getService().getGlobalSettings().getBoolean("rootMode", false))
{
String muteCommand;
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.LOLLIPOP)
muteCommand = "input keyevent 79";
else
muteCommand = "input keyevent 91";
try {
Runtime.getRuntime().exec(new String[] {"su", "-c", muteCommand});
} catch (IOException e) {
e.printStackTrace();
}
}
else
{
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
audioManager.setMicrophoneMute(microphoneMuted);
}
getCallModule().updatePebble();
}
@Override
public int getIcon()
{
return microphoneMuted ? CallAction.ICON_BUTTON_MIC_OFF : CallAction.ICON_BUTTON_MIC_ON;
}
public static ToggleMicrophoneAction get(CallModule callModule)
{
return (ToggleMicrophoneAction) callModule.getCallAction(TOGGLE_MICROPHONE_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/ToggleRingerAction.java
public class ToggleRingerAction extends CallAction
{
public static final int TOGGLE_RINGER_ACTION_ID = 2;
private boolean isMutedViaAudioManager = false;
private int prevRingerMode = AudioManager.RINGER_MODE_NORMAL;
public ToggleRingerAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.RINGING)
return;
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
getCallModule().setVibration(false);
if (!isMutedViaAudioManager)
{
if (getCallModule().getService().getGlobalSettings().getBoolean("rootMode", false))
{
Timber.d("Muting using root method...");
try {
Runtime.getRuntime().exec(new String[] {"su", "-c", "input keyevent " + KeyEvent.KEYCODE_VOLUME_DOWN});
} catch (IOException e) {
e.printStackTrace();
}
}
else if (canMuteRinger(getCallModule().getService()))
{
isMutedViaAudioManager = true;
prevRingerMode = audioManager.getRingerMode();
audioManager.setStreamSolo(AudioManager.STREAM_MUSIC, true);
audioManager.setRingerMode(AudioManager.RINGER_MODE_SILENT);
}
}
else if (canMuteRinger(getCallModule().getService()))
{
isMutedViaAudioManager = false;
audioManager.setStreamSolo(AudioManager.STREAM_MUSIC, false);
audioManager.setRingerMode(prevRingerMode);
}
getCallModule().updatePebble();
}
public void mute()
{
if (!isMutedViaAudioManager)
executeAction();
}
public static boolean canMuteRinger(Context context)
{
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M)
return true;
NotificationManager notificationManager = (NotificationManager) context.getSystemService(Context.NOTIFICATION_SERVICE);
return notificationManager.isNotificationPolicyAccessGranted();
}
@Override
public void onCallEnd()
{ if (isMutedViaAudioManager && canMuteRinger(getCallModule().getService()))
{
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
isMutedViaAudioManager = false;
audioManager.setStreamSolo(AudioManager.STREAM_MUSIC, false);
audioManager.setRingerMode(prevRingerMode);
}
getCallModule().setVibration(true);
}
@Override
public int getIcon()
{
return isMutedViaAudioManager ? CallAction.ICON_BUTTON_SPEAKER_OFF : CallAction.ICON_BUTTON_SPEKAER_ON;
}
public static ToggleRingerAction get(CallModule callModule)
{
return (ToggleRingerAction) callModule.getCallAction(TOGGLE_RINGER_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/AnswerCallWithSpeakerAction.java
public class AnswerCallWithSpeakerAction extends CallAction
{
public static final int ANSWER_WITH_SPEAKER_ACTION_ID = 5;
private boolean enableSpeakerImmediately = false;
public AnswerCallWithSpeakerAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.RINGING)
return;
enableSpeakerImmediately = true;
AnswerCallAction.get(getCallModule()).executeAction();
}
@Override
public void onCallEnd()
{
enableSpeakerImmediately = false; //Reset intent (there will be new intent at next call)
}
@Override
public void onPhoneOffhook()
{
if (enableSpeakerImmediately)
{
ToggleSpeakerAction speakerAction = ToggleSpeakerAction.get(getCallModule());
if (!speakerAction.isSpeakerphoneEnabled())
speakerAction.executeAction();
}
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_ANSWER;
}
public static AnswerCallWithSpeakerAction get(CallModule callModule)
{
return (AnswerCallWithSpeakerAction) callModule.getCallAction(ANSWER_WITH_SPEAKER_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/VolumeDownAction.java
public class VolumeDownAction extends CallAction
{
public static final int VOLUME_DOWN_ACTION_ID = 7;
public VolumeDownAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.ESTABLISHED)
return;
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
audioManager.adjustStreamVolume(AudioManager.STREAM_VOICE_CALL, AudioManager.ADJUST_LOWER, 0);
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_VOLUME_DOWN;
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/CallAction.java
public abstract class CallAction
{
public static final int ICON_BUTTON_ANSWER = 0;
public static final int ICON_BUTTON_END_CALL = 1;
public static final int ICON_BUTTON_MIC_ON = 2;
public static final int ICON_BUTTON_MIC_OFF = 3;
public static final int ICON_BUTTON_SPEKAER_ON = 4;
public static final int ICON_BUTTON_SPEAKER_OFF = 5;
public static final int ICON_BUTTON_VOLUME_DOWN = 6;
public static final int ICON_BUTTON_VOLUME_UP = 7;
public static final int ICON_BLANK = 0xFF;
private CallModule callModule;
public CallAction(CallModule callModule)
{
this.callModule = callModule;
}
public CallModule getCallModule()
{
return callModule;
}
public void onPhoneOffhook()
{
}
public void onCallRinging()
{
}
public void onCallEnd()
{
}
public abstract void executeAction();
public abstract int getIcon();
}
app/src/main/java/com/matejdro/pebbledialer/notifications/JellybeanNotificationListener.java
@TargetApi(value = Build.VERSION_CODES.JELLY_BEAN_MR2)
public class JellybeanNotificationListener extends NotificationListenerService {
private static JellybeanNotificationListener instance = null;
@Override
public void onDestroy() {
Timber.d("Notification Listener stopped...");
super.onDestroy();
instance = null;
}
@Override
public void onCreate() {
Timber.d("Creating Notification Listener...");
super.onCreate();
instance = this;
}
public static boolean isActive()
{
return instance != null;
}
@TargetApi(value = Build.VERSION_CODES.LOLLIPOP)
public static boolean isPhoneInDoNotInterrupt()
{
if (instance == null)
return false;
int interruptionFilter = instance.getCurrentInterruptionFilter();
Timber.d("Interrupt filter: %d", interruptionFilter);
return interruptionFilter != NotificationListenerService.INTERRUPTION_FILTER_ALL && interruptionFilter != 0;
}
@Override
public void onNotificationPosted(final StatusBarNotification sbn) {
Timber.d("Got new jellybean notification");
NotificationHandler.newNotification(JellybeanNotificationListener.this, sbn.getPackageName(), sbn.getNotification());
}
@Override
public void onNotificationRemoved(StatusBarNotification sbn) {
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/ToggleSpeakerAction.java
public class ToggleSpeakerAction extends CallAction
{
public static final int TOGGLE_SPEAKER_ACTION_ID = 4;
private boolean speakerphoneEnabled = false;
public ToggleSpeakerAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.ESTABLISHED)
return;
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
speakerphoneEnabled = !speakerphoneEnabled;
audioManager.setSpeakerphoneOn(speakerphoneEnabled);
getCallModule().updatePebble();
}
public boolean isSpeakerphoneEnabled()
{
return speakerphoneEnabled;
}
private void updateSpeakerphoneEnabled()
{
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
speakerphoneEnabled = audioManager.isSpeakerphoneOn();
}
@Override
public void onPhoneOffhook()
{
updateSpeakerphoneEnabled();
}
@Override
public int getIcon()
{
return speakerphoneEnabled ? ICON_BUTTON_SPEKAER_ON : ICON_BUTTON_SPEAKER_OFF;
}
public static ToggleSpeakerAction get(CallModule callModule)
{
return (ToggleSpeakerAction) callModule.getCallAction(TOGGLE_SPEAKER_ACTION_ID);
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/VolumeUpAction.java
public class VolumeUpAction extends CallAction
{
public static final int VOLUME_UP_ACTION_ID = 8;
public VolumeUpAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
if (getCallModule().getCallState() != CallModule.CallState.ESTABLISHED)
return;
AudioManager audioManager = (AudioManager) getCallModule().getService().getSystemService(Context.AUDIO_SERVICE);
audioManager.adjustStreamVolume(AudioManager.STREAM_VOICE_CALL, AudioManager.ADJUST_RAISE, 0);
}
@Override
public int getIcon()
{
return CallAction.ICON_BUTTON_VOLUME_UP;
}
}
app/src/main/java/com/matejdro/pebbledialer/callactions/DummyAction.java
public class DummyAction extends CallAction
{
public static final int DUMMY_ACTION_ID = 999;
public DummyAction(CallModule callModule)
{
super(callModule);
}
@Override
public void executeAction()
{
}
@Override
public int getIcon()
{
return CallAction.ICON_BLANK;
}
public static DummyAction get(CallModule callModule)
{
return (DummyAction) callModule.getCallAction(DUMMY_ACTION_ID);
}
}
package com.matejdro.pebbledialer.modules;
import android.app.PendingIntent;
import android.content.Intent;
import android.content.SharedPreferences;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.net.Uri;
import android.os.Build;
import android.provider.ContactsContract;
import android.provider.MediaStore;
import android.service.notification.NotificationListenerService;
import android.telephony.TelephonyManager;
import android.util.SparseArray;
import com.getpebble.android.kit.util.PebbleDictionary;
import com.matejdro.pebblecommons.pebble.CommModule;
import com.matejdro.pebblecommons.pebble.PebbleCommunication;
import com.matejdro.pebblecommons.pebble.PebbleImageToolkit;
import com.matejdro.pebblecommons.pebble.PebbleTalkerService;
import com.matejdro.pebblecommons.pebble.PebbleUtil;
import com.matejdro.pebblecommons.util.ContactUtils;
import com.matejdro.pebblecommons.util.Size;
import com.matejdro.pebblecommons.util.TextUtil;
import com.matejdro.pebblecommons.vibration.PebbleVibrationPattern;
import com.matejdro.pebbledialer.callactions.AnswerCallAction;
import com.matejdro.pebbledialer.callactions.AnswerCallWithSpeakerAction;
import com.matejdro.pebbledialer.callactions.CallAction;
import com.matejdro.pebbledialer.callactions.DummyAction;
import com.matejdro.pebbledialer.callactions.EndCallAction;
import com.matejdro.pebbledialer.callactions.SMSReplyAction;
import com.matejdro.pebbledialer.callactions.ToggleMicrophoneAction;
import com.matejdro.pebbledialer.callactions.ToggleRingerAction;
import com.matejdro.pebbledialer.callactions.ToggleSpeakerAction;
import com.matejdro.pebbledialer.callactions.VolumeDownAction;
import com.matejdro.pebbledialer.callactions.VolumeUpAction;
import com.matejdro.pebbledialer.notifications.JellybeanNotificationListener;
import java.io.IOException;
import java.util.Calendar;
import java.util.List;
import timber.log.Timber;
public class CallModule extends CommModule
{
public static final String INTENT_CALL_STATUS = "CallStatus";
public static final String INTENT_ACTION_FROM_NOTIFICATION = "ActionFromNotification";
public static int MODULE_CALL = 1;
private SparseArray<CallAction> actions = new SparseArray<CallAction>();
private boolean updateRequired;
private boolean identityUpdateRequired;
private boolean callerNameUpdateRequired;
private int callerImageNextByte = -1;
private String number = "Outgoing Call";
private String name = null;
private String type = null;
private Bitmap callerImage = null;
private byte[] callerImageBytes;
private CallState callState = CallState.NO_CALL;
private boolean vibrating;
private boolean closeAutomaticallyAfterThisCall = true;
long callStartTime;
public CallModule(PebbleTalkerService service)
{
super(service);
service.registerIntent(INTENT_CALL_STATUS, this);
service.registerIntent(INTENT_ACTION_FROM_NOTIFICATION, this);
registerCallAction(new AnswerCallAction(this), AnswerCallAction.ANSWER_ACTION_ID);
registerCallAction(new EndCallAction(this), EndCallAction.END_CALL_ACTION_ID);Next line of code:

View file

@ -48,6 +48,8 @@ PHI3VISION_IDS = ['microsoft/phi-3-vision-128k-instruct']
QWENVL_IDS = ['Qwen/Qwen-VL-Chat']
MINICPM_V_IDS = ['openbmb/MiniCPM-V-2_6']
results = []
excludes = []
@ -77,20 +79,22 @@ def run_model_in_thread(model, in_out, tokenizer, result, warm_up, num_beams, in
actual_in_len, actual_out_len, load_time, model.peak_memory])
def get_continuation_input_str(in_len, tokenizer):
def get_continuation_input_str(in_len, tokenizer=None):
# keep 'utf-8' as character encoding mode
if in_len <= 2048:
input_str = open("prompt/summarize/cnn_39.txt", 'r', encoding='utf-8').read()
question = "Can you please summarize this article?"
prompt_format = "{}\nQuestion: {}\n"
full_prompt = prompt_format.format(input_str, question)
input_ids = tokenizer.encode(full_prompt, return_tensors="pt")
if input_ids.shape[1] < in_len:
return open(f"prompt/continuation/8192.txt", 'r', encoding='utf-8').read()
else:
half_idx = in_len // 2
input_ids_truncated = torch.cat((input_ids[:, :half_idx], input_ids[:, -(in_len-half_idx):]), dim=1)
return tokenizer.batch_decode(input_ids_truncated)[0]
if tokenizer is not None:
if in_len > 128 and in_len <= 4096:
if in_len <= 2048:
input_str = open("prompt/continuation/longbench_2k.txt", 'r', encoding='utf-8').read()
else:
input_str = open("prompt/continuation/longbench_4k.txt", 'r', encoding='utf-8').read()
input_ids = tokenizer.encode(input_str, return_tensors="pt")
if input_ids.shape[1] < in_len:
return open(f"prompt/continuation/8192.txt", 'r', encoding='utf-8').read()
else:
half_idx = in_len // 2
input_ids_truncated = torch.cat((input_ids[:, :half_idx], input_ids[:, -(in_len-half_idx):]), dim=1)
return tokenizer.batch_decode(input_ids_truncated)[0]
return open(f"prompt/continuation/8192.txt", 'r', encoding='utf-8').read()
@ -239,7 +243,7 @@ def run_native_int4(repo_id,
in_out_len = in_out.split("-")
in_len = int(in_out_len[0])
out_len = int(in_out_len[1])
input_str = get_continuation_input_str(in_len, tokenizer)
input_str = get_continuation_input_str(in_len)
# As different tokenizer has different encodings,
# slice the input_ids to ensure the prompt length is required length.
n_ctx = in_len + out_len if in_len + out_len > 512 else 512
@ -476,13 +480,15 @@ def run_transformer_int4_gpu(repo_id,
else:
model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
trust_remote_code=True, use_cache=True,
torch_dtype=torch_dtype).eval()
cpu_embedding=cpu_embedding, torch_dtype=torch_dtype).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, cpu_embedding=cpu_embedding)
model = model.to('xpu')
elif origin_repo_id in LLAMA_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_low_bit=low_bit, trust_remote_code=True,
use_cache=True, cpu_embedding=cpu_embedding,
torch_dtype=torch_dtype).eval()
tokenizer = LlamaTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif origin_repo_id in PHI3VISION_IDS:
model = AutoModelForCausalLM.from_pretrained(model_path, optimize_model=True, load_in_low_bit=low_bit,
_attn_implementation="eager",
@ -490,6 +496,15 @@ def run_transformer_int4_gpu(repo_id,
trust_remote_code=True, use_cache=True,
cpu_embedding=cpu_embedding, torch_dtype=torch_dtype).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
elif origin_repo_id in MINICPM_V_IDS:
model = AutoModel.from_pretrained(model_path, load_in_low_bit=low_bit, optimize_model=True,
modules_to_not_convert=["vpm", "resampler"],
trust_remote_code=True, use_cache=True,
cpu_embedding=cpu_embedding, torch_dtype=torch_dtype).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
model = model.llm
else:
if "4bit" in repo_id:
model = AutoModelForCausalLM.load_low_bit(model_path, optimize_model=True,
@ -511,8 +526,7 @@ def run_transformer_int4_gpu(repo_id,
cpu_embedding=cpu_embedding,
torch_dtype=torch_dtype).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to('xpu')
model = model.to('xpu')
end = time.perf_counter()
load_time = end - st

View file

@ -396,6 +396,8 @@ class _BaseAutoModelClass:
from .lookup import lookup_generate
import types
model.lookup_generate = types.MethodType(lookup_generate, model)
if model.config.model_type == "minicpmv" and hasattr(model, 'llm'):
model.llm.lookup_generate = types.MethodType(lookup_generate, model.llm)
else:
# load default
model = cls.HF_Model.from_pretrained(*args, **kwargs)