Spaces:
Paused
Paused
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -22,6 +22,7 @@ import transformers
|
|
| 22 |
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
|
| 23 |
#import auto_gptq
|
| 24 |
#from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def reset_state():
|
|
@@ -99,6 +100,19 @@ def load_tokenizer_and_model(base_model,load_8bit=False):
|
|
| 99 |
return tokenizer,model,device
|
| 100 |
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
|
| 103 |
if torch.cuda.is_available():
|
| 104 |
device = "cuda"
|
|
|
|
| 22 |
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
|
| 23 |
#import auto_gptq
|
| 24 |
#from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
|
| 25 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
| 26 |
|
| 27 |
|
| 28 |
def reset_state():
|
|
|
|
| 100 |
return tokenizer,model,device
|
| 101 |
|
| 102 |
|
| 103 |
+
def load_tokenizer_and_model_Baize(base_model, load_8bit=True):
    """Load a Baize (LLaMA-based) tokenizer and causal LM.

    Args:
        base_model: Hugging Face model id or local path passed to
            ``from_pretrained``.
        load_8bit: whether to load the model weights in 8-bit
            quantized form (default ``True``).

    Returns:
        Tuple ``(tokenizer, model, device)`` where ``device`` is
        ``"cuda"`` when a GPU is available, else ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # add_eos_token so generation stops cleanly; use_auth_token for
    # gated LLaMA checkpoints.
    tokenizer = LlamaTokenizer.from_pretrained(
        base_model, add_eos_token=True, use_auth_token=True
    )
    # BUG FIX: the original hard-coded load_in_8bit=True, silently
    # ignoring the load_8bit parameter.  Honor the caller's choice
    # (the default value is unchanged, so existing callers behave
    # identically).
    model = LlamaForCausalLM.from_pretrained(
        base_model, load_in_8bit=load_8bit, device_map="auto"
    )

    return tokenizer, model, device
|
| 114 |
+
|
| 115 |
+
|
| 116 |
def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
|
| 117 |
if torch.cuda.is_available():
|
| 118 |
device = "cuda"
|