更新 Mamba/mamba-main/train.py
This commit is contained in:
		
							parent
							
								
									9eea6c07af
								
							
						
					
					
						commit
						f65b091fac
					
				| @ -1,99 +1,99 @@ | |||||||
| import os | import os | ||||||
| import pandas as pd | import pandas as pd | ||||||
| from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, MambaConfig | from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, MambaConfig | ||||||
| from trl import SFTTrainer | from trl import SFTTrainer | ||||||
| from peft import LoraConfig | from peft import LoraConfig | ||||||
| from datasets import Dataset | from datasets import Dataset | ||||||
| 
 | import torch | ||||||
| # 设置环境变量来避免内存碎片化 | # 设置环境变量来避免内存碎片化 | ||||||
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" | ||||||
| 
 | 
 | ||||||
| # 数据文件夹路径 | # 数据文件夹路径 | ||||||
| data_folder = r'/mnt/Mamba/mamba-main/data/dataset' | data_folder = r'/mnt/Mamba/mamba-main/data/dataset' | ||||||
| 
 | 
 | ||||||
| # 检查路径是否存在 | # 检查路径是否存在 | ||||||
| if not os.path.exists(data_folder): | if not os.path.exists(data_folder): | ||||||
|     raise ValueError(f"路径不存在: {data_folder}") |     raise ValueError(f"路径不存在: {data_folder}") | ||||||
| 
 | 
 | ||||||
| # 加载分词器和模型 | # 加载分词器和模型 | ||||||
| path = "/mnt/Mamba/mamba-130m-hf"  # 模型路径 | path = "/mnt/Mamba/mamba-130m-hf"  # 模型路径 | ||||||
| tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True) | tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=True) | ||||||
| model = AutoModelForCausalLM.from_pretrained(path, local_files_only=True, num_labels=8, use_mambapy=True) | model = AutoModelForCausalLM.from_pretrained(path, local_files_only=True, num_labels=8, use_mambapy=True) | ||||||
| 
 | 
 | ||||||
| print("加载成功") | print("加载成功") | ||||||
| 
 | 
 | ||||||
| # 配置训练参数 | # 配置训练参数 | ||||||
| training_args = TrainingArguments( | training_args = TrainingArguments( | ||||||
|     output_dir="./results", |     output_dir="./results", | ||||||
|     num_train_epochs=3, |     num_train_epochs=3, | ||||||
|     per_device_train_batch_size=12,  # 减少批处理大小 |     per_device_train_batch_size=12,  # 减少批处理大小 | ||||||
|     logging_dir='./logs', |     logging_dir='./logs', | ||||||
|     logging_steps=10, |     logging_steps=10, | ||||||
|     learning_rate=2e-3, |     learning_rate=2e-3, | ||||||
|     gradient_accumulation_steps=2,  # 使用梯度累积减少显存占用 |     gradient_accumulation_steps=2,  # 使用梯度累积减少显存占用 | ||||||
|     fp16=True,  # 启用混合精度训练 |     fp16=True,  # 启用混合精度训练 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| # LoRA配置 | # LoRA配置 | ||||||
| lora_config = LoraConfig( | lora_config = LoraConfig( | ||||||
|     r=8,  # 低秩分解的秩 |     r=8,  # 低秩分解的秩 | ||||||
|     target_modules=["x_proj", "embeddings", "in_proj", "out_proj"], |     target_modules=["x_proj", "embeddings", "in_proj", "out_proj"], | ||||||
|     task_type="SEQ_CLS",  # 序列分类任务类型 |     task_type="SEQ_CLS",  # 序列分类任务类型 | ||||||
|     bias="none" |     bias="none" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| # 初始化Trainer | # 初始化Trainer | ||||||
| trainer = SFTTrainer( | trainer = SFTTrainer( | ||||||
|     model=model, |     model=model, | ||||||
|     tokenizer=tokenizer, |     tokenizer=tokenizer, | ||||||
|     args=training_args, |     args=training_args, | ||||||
|     peft_config=lora_config, |     peft_config=lora_config, | ||||||
|     max_seq_length=512,  # 设置max_seq_length参数 |     max_seq_length=512,  # 设置max_seq_length参数 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| # 分块加载和处理数据 | # 分块加载和处理数据 | ||||||
| chunksize = 40000  # 设置合适的分块大小,每次读取数据的行数 | chunksize = 40000  # 设置合适的分块大小,每次读取数据的行数 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def preprocess_data(chunk): | def preprocess_data(chunk): | ||||||
|     chunk = chunk.dropna()  # 处理缺失值 |     chunk = chunk.dropna()  # 处理缺失值 | ||||||
|     texts = chunk[["acc_x", "acc_y", "acc_z", "gyr_x", "gyr_y", "gyr_z", "mag_x", "mag_y", "mag_z"]].astype(str).apply( |     texts = chunk[["acc_x", "acc_y", "acc_z", "gyr_x", "gyr_y", "gyr_z", "mag_x", "mag_y", "mag_z"]].astype(str).apply( | ||||||
|         ' '.join, axis=1).tolist() |         ' '.join, axis=1).tolist() | ||||||
|     labels = chunk["Person_id"].astype(int).tolist()  # 确保标签是整数类型 |     labels = chunk["Person_id"].astype(int).tolist()  # 确保标签是整数类型 | ||||||
|     encodings = tokenizer(texts, truncation=True, padding=True, max_length=1024) |     encodings = tokenizer(texts, truncation=True, padding=True, max_length=1024) | ||||||
|     return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": labels} |     return {"input_ids": encodings["input_ids"], "attention_mask": encodings["attention_mask"], "labels": labels} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # 读取训练数据并进行训练 | # 读取训练数据并进行训练 | ||||||
| train_file_path = os.path.join(data_folder, 'train_data.csv') | train_file_path = os.path.join(data_folder, 'train_data.csv') | ||||||
| chunk_iter = pd.read_csv(train_file_path, chunksize=chunksize, header=0) | chunk_iter = pd.read_csv(train_file_path, chunksize=chunksize, header=0) | ||||||
| 
 | 
 | ||||||
| for chunk in chunk_iter: | for chunk in chunk_iter: | ||||||
|     # 数据预处理 |     # 数据预处理 | ||||||
|     processed_data = preprocess_data(chunk) |     processed_data = preprocess_data(chunk) | ||||||
|     dataset = Dataset.from_dict(processed_data) |     dataset = Dataset.from_dict(processed_data) | ||||||
| 
 | 
 | ||||||
|     # 训练模型 |     # 训练模型 | ||||||
|     trainer.train_dataset = dataset |     trainer.train_dataset = dataset | ||||||
|     trainer.train() |     trainer.train() | ||||||
| 
 | 
 | ||||||
|     # 清理CUDA缓存 |     # 清理CUDA缓存 | ||||||
|     torch.cuda.empty_cache() |     torch.cuda.empty_cache() | ||||||
| 
 | 
 | ||||||
| # 保存训练后的模型 | # 保存训练后的模型 | ||||||
| model.save_pretrained("./trained_model") | model.save_pretrained("./trained_model") | ||||||
| tokenizer.save_pretrained("./trained_model") | tokenizer.save_pretrained("./trained_model") | ||||||
| 
 | 
 | ||||||
| print("模型保存成功") | print("模型保存成功") | ||||||
| 
 | 
 | ||||||
| # 读取测试数据并进行预测 | # 读取测试数据并进行预测 | ||||||
| test_file_path = os.path.join(data_folder, 'test_data.csv') | test_file_path = os.path.join(data_folder, 'test_data.csv') | ||||||
| test_data = pd.read_csv(test_file_path, header=0) | test_data = pd.read_csv(test_file_path, header=0) | ||||||
| processed_test_data = preprocess_data(test_data) | processed_test_data = preprocess_data(test_data) | ||||||
| test_dataset = Dataset.from_dict(processed_test_data) | test_dataset = Dataset.from_dict(processed_test_data) | ||||||
| 
 | 
 | ||||||
| # 预测Person_id | # 预测Person_id | ||||||
| predictions = trainer.predict(test_dataset) | predictions = trainer.predict(test_dataset) | ||||||
| 
 | 
 | ||||||
| # 输出预测结果 | # 输出预测结果 | ||||||
| print(predictions) | print(predictions) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user