Fix speech_paraformer issue with unexpected changes (#12416)
* Fix speech_paraformer issue with unexpected changes * Add paraformer version specified
This commit is contained in:
		
							parent
							
								
									a9cb70a71c
								
							
						
					
					
						commit
						ff3f7cb25f
					
				
					 3 changed files with 20 additions and 16 deletions
				
			
		| 
						 | 
				
			
			@ -37,8 +37,8 @@ pip install timm torch==2.1.2 torchvision==0.16.2
 | 
			
		|||
pip install BCEmbedding==0.1.5 transformers==4.40.0
 | 
			
		||||
 | 
			
		||||
# [optional] for Speech_Paraformer-Large
 | 
			
		||||
pip install -U funasr
 | 
			
		||||
pip install modelscope torch==2.1.2 torchaudio==2.1.2
 | 
			
		||||
pip install funasr==1.1.14
 | 
			
		||||
pip install modelscope==1.20.1 torch==2.1.2 torchaudio==2.1.2
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 2. Runtime Configurations
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -162,10 +162,12 @@ class _BaseAutoModelClass:
 | 
			
		|||
                model = cls.HF_Model.from_pretrained(*args, **kwargs)
 | 
			
		||||
            else:
 | 
			
		||||
                model = cls.HF_Model(*args, **kwargs)
 | 
			
		||||
            if hasattr(model, "config"):
 | 
			
		||||
                model.config.update({"bigdl_lcmu_enabled": False})
 | 
			
		||||
 | 
			
		||||
        logger.info(f"Converting model, it may takes up to several minutes ...")
 | 
			
		||||
 | 
			
		||||
        if hasattr(model, "config"):
 | 
			
		||||
            model.config.update({"optimize_model": optimize_model})
 | 
			
		||||
 | 
			
		||||
        if mock_device == "cpu":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -294,17 +294,17 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
 | 
			
		|||
            torch.Tensor: result
 | 
			
		||||
        """
 | 
			
		||||
        backend_cls = self.backend_cls_prefill
 | 
			
		||||
        inputs = (x,
 | 
			
		||||
                  masks,
 | 
			
		||||
                  self.layer_norm_0_weight,
 | 
			
		||||
                  self.layer_norm_0_bias,
 | 
			
		||||
                  self.layer_norm_1_weight,
 | 
			
		||||
                  self.layer_norm_1_bias,
 | 
			
		||||
                  self.fsmn_weight,
 | 
			
		||||
                  self.qkv_bias,
 | 
			
		||||
                  self.out_bias,
 | 
			
		||||
                  self.w1_bias,
 | 
			
		||||
                  self.w2_bias,
 | 
			
		||||
        inputs = (x.to(torch.float16),
 | 
			
		||||
                  masks.to(torch.float16),
 | 
			
		||||
                  self.layer_norm_0_weight.to(torch.float16),
 | 
			
		||||
                  self.layer_norm_0_bias.to(torch.float16),
 | 
			
		||||
                  self.layer_norm_1_weight.to(torch.float16),
 | 
			
		||||
                  self.layer_norm_1_bias.to(torch.float16),
 | 
			
		||||
                  self.fsmn_weight.to(torch.float16),
 | 
			
		||||
                  self.qkv_bias.to(torch.float16),
 | 
			
		||||
                  self.out_bias.to(torch.float16),
 | 
			
		||||
                  self.w1_bias.to(torch.float16),
 | 
			
		||||
                  self.w2_bias.to(torch.float16),
 | 
			
		||||
                  )
 | 
			
		||||
 | 
			
		||||
        outputs = run_model(
 | 
			
		||||
| 
						 | 
				
			
			@ -431,6 +431,8 @@ class PrefillRunner:
 | 
			
		|||
        args = (xs_pad, masks)
 | 
			
		||||
        self.prefill_input_queue.put(args)
 | 
			
		||||
        xs_pad, masks = self.prefill_result_queue.get()
 | 
			
		||||
        xs_pad = xs_pad.to(torch.float32)
 | 
			
		||||
        masks = masks.to(torch.float32)
 | 
			
		||||
        return xs_pad, masks
 | 
			
		||||
 | 
			
		||||
    def shutdown(self):
 | 
			
		||||
| 
						 | 
				
			
			@ -639,7 +641,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
 | 
			
		|||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.do_print = True
 | 
			
		||||
        self.do_print = do_print
 | 
			
		||||
 | 
			
		||||
        op_parameters = []
 | 
			
		||||
        for w in parameters:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue