当前位置: 首页 > news >正文

随机生成pytorch算子测试序列且保证算子参数合法

随机生成pytorch算子测试序列且保证算子参数合法

  • 代码
  • 输出

背景:

1.一些对维度进行操作的算子的单算子测试,结果正常,但多个算子组合在一起,结果就不对。是否能给一个算子列表,随机生成它们的组合呢

功能描述:

1.此程序用于在 CUDA 环境中生成随机张量并对其施加一系列随机选择的操作

2.程序首先随机生成张量的形状和内容,然后随机选择一个操作(如 reshapetransposematmul 等),并生成适当的参数以执行该操作

3.最终输出变换后的张量并打印相关操作信息

4.整个过程在100次不同的种子下每次进行10次操作,以保证操作的多样性和结果的随机性

通过LLM多轮对话生成pytorch算子组合测试用例 小结初衷: 给一个算子列表,自动生成列表中算子的随机组合测试,可以覆盖不同的shape,支持任意多个算子的组合原计划(LLM全自动生成):
1.测试了qwen-max、kimi moonshot-v1-128k、ERNIE-4.0-8K、sparkai(3.5)、yi-large这几个模型(各家最新的模型)
2.这几个模型都能按要求生成单元测试用例,但几乎所有的代码运行都会出错(95%以上的错都是shape不匹配)
3.一些模型通过几次交互能修复bug,但整体上效果不理想
4.也许LLM对pytorch算子的约束不太了解,可以尝试将算子的接口文档告诉LLM,采用few shot的方式,是否有所改善。妥协的方案:
1.于是将这个需求细化,与GPT-4o多次交互,生成了这个功能模块的代码,功能正常。之后加到ut测试中

代码

import torch
import random
from functools import reduce
from operator import mul
import numpy as npmax_size = 4096  # 每个维度的最大大小
max_tensor_elements = 1*4096*4096  # 张量中元素的总数限制min_dim_size = 1  # 最小维度大小
max_dim_size = max_size  # 扩大这个范围可以更快生成符合要求的大小def generate_random_shape(dim, max_attempts=10):for _ in range(max_attempts):shape = [random.randint(min_dim_size, max_dim_size) for _ in range(dim)]if reduce(mul, shape, 1) <= max_tensor_elements:return tuple(shape)# 兜底策略,防止尝试次数用尽:再遍历生成的随机形状,逐个将维度缩小直到符合限制shape = [random.randint(1, max_size) for _ in range(dim)]current_elements = reduce(mul, shape, 1)while current_elements > max_tensor_elements:for i in range(len(shape)):if shape[i] > 1:shape[i] //= 2current_elements = reduce(mul, shape, 1)if current_elements <= max_tensor_elements:breakreturn tuple(shape)def generate_random_input(shape):return torch.randn(shape).to("cuda").half()def generate_random_operator(input_shape):operators = ['unsqueeze', 'repeat', 'permute', 'transpose', 'reshape', 'expand', 'contiguous', 'matmul', 'mul', 'concat',"view"]return random.choice(operators)def generate_random_reshape(input_shape):# 计算输入张量的总元素数total_elements = np.prod(input_shape)divisors = []# 找到 total_elements 的所有约数for i in range(1, int(np.sqrt(total_elements)) + 1):if total_elements % i == 0:divisors.append(i)if i != total_elements // i:divisors.append(total_elements // i)dimensions = []remaining_elements = total_elements# 随机选择新的维度并且保证元素数量不变while remaining_elements > 1 and len(dimensions) < len(input_shape):divisor = np.random.choice(divisors)dimensions.append(divisor)remaining_elements //= divisordivisors = [d for d in divisors if remaining_elements % d == 0]if remaining_elements > 1:dimensions.append(remaining_elements)    np.random.shuffle(dimensions)    return tuple(dimensions)def generate_reshape_params(tensor):return generate_random_reshape(tensor.shape)def random_transpose_params(tensor):return random.sample(range(tensor.dim()), 2)def generate_repeat_params(input_shape):while True:repeats = [random.randint(1, 4) for _ in input_shape]if reduce(mul, [dim * repeat for dim, repeat in zip(input_shape, repeats)], 1) <= max_tensor_elements:return tuple(repeats)def generate_expand_params(input_shape):expanded_shape = []while True:expanded_shape = [random.randint(min(2,dim), dim*2) if dim == 1 else dim for dim in input_shape]if reduce(mul, expanded_shape, 1) <= max_tensor_elements:breakreturn expanded_shapedef generate_random_operator_parameters(input_shape, operator, input_tensor):if operator == 'unsqueeze':return (random.randint(0, len(input_shape) - 1),)if operator == 'repeat':return generate_repeat_params(input_shape)if operator == 'permute':return random.sample(range(len(input_shape)), len(input_shape))if operator == 'transpose':return random_transpose_params(input_tensor)if operator in ['reshape',"view"]:return generate_reshape_params(input_tensor)if operator == 'expand':return generate_expand_params(input_shape)if operator == 'matmul':if input_tensor.dim() == 1:return ()return (input_tensor.size(-1), random.randint(1, max_size))if operator in ['contiguous','mul']:return ()if operator == 'concat':return (random.randint(0, len(input_shape) - 1),)def execute_operator(input_tensor, operator, operator_parameters):if operator == 'unsqueeze':return input_tensor.unsqueeze(*operator_parameters)if operator == 'repeat':return input_tensor.repeat(operator_parameters)if operator == 'permute':return input_tensor.permute(operator_parameters)if operator == 'transpose':return input_tensor.transpose(*operator_parameters)if operator == 'reshape':return input_tensor.reshape(operator_parameters)if operator == 'view':return input_tensor.view(operator_parameters)    if operator == 'expand':return input_tensor.expand(operator_parameters)if operator == 'contiguous':return input_tensor.contiguous()if operator == 'matmul':if input_tensor.dim() ==1:return input_tensorother = torch.randn(*operator_parameters).to(input_tensor.device).type_as(input_tensor)return torch.matmul(input_tensor, other)if operator == 'mul':return input_tensor * input_tensorif operator == 'concat':return torch.cat((input_tensor, input_tensor), dim=operator_parameters[0])def main():for seed in range(2):random.seed(seed)np.random.seed(seed)torch.random.manual_seed(seed)for i in range(10):input_shape = generate_random_shape(random.randint(2, 5))input_tensor = generate_random_input(input_shape)operator = generate_random_operator(input_shape)operator_parameters = generate_random_operator_parameters(input_shape, operator, input_tensor)output_tensor = execute_operator(input_tensor, operator, operator_parameters)print(f"seed:{seed:03d} seq:{i:02d} {operator:<10} input:{str(input_shape):<32} param:{str(operator_parameters):<32} output:{str(output_tensor.shape):<32}")print(output_tensor.cpu().numpy().reshape(-1)[:8])torch.cuda.empty_cache()
if __name__ == '__main__':main()

输出

seed:000 seq:00 repeat     input:(7, 42, 26, 36, 56)              param:(1, 1, 1, 1, 1)                  output:torch.Size([7, 42, 26, 36, 56])
seed:000 seq:01 view       input:(248, 227, 276)                  param:(92, 908, 186)                   output:torch.Size([92, 908, 186])
seed:000 seq:02 view       input:(18, 21, 51, 32, 17)             param:(17, 4536, 136)                  output:torch.Size([17, 4536, 136])
seed:000 seq:03 reshape    input:(2548, 3565)                     param:(644, 65, 217)                   output:torch.Size([644, 65, 217])
seed:000 seq:04 reshape    input:(46, 42, 14, 57, 7)              param:(28, 266, 3, 483)                output:torch.Size([28, 266, 3, 483])
seed:000 seq:05 contiguous input:(222, 100, 597)                  param:()                               output:torch.Size([222, 100, 597])
seed:000 seq:06 view       input:(15, 27, 56, 8, 59)              param:(3, 3, 20160, 1, 59)             output:torch.Size([3, 3, 20160, 1, 59])
seed:000 seq:07 view       input:(1461, 1161)                     param:(188469, 9)                      output:torch.Size([188469, 9])
seed:000 seq:08 reshape    input:(19, 29, 19, 17, 54)             param:(31407, 1, 3, 17, 6, 1)          output:torch.Size([31407, 1, 3, 17, 6, 1])
seed:000 seq:09 transpose  input:(12, 126, 46, 157)               param:[2, 3]                           output:torch.Size([12, 126, 157, 46])
[-0.581   0.568   1.187   2.46   -0.1392 -0.3362  0.2076 -0.662 ]
seed:001 seq:00 view       input:(119, 354, 236)                  param:(4, 1, 17, 146202)               output:torch.Size([4, 1, 17, 146202])
seed:001 seq:01 reshape    input:(60, 961, 178)                   param:(3, 3421160)                     output:torch.Size([3, 3421160])
seed:001 seq:02 expand     input:(16, 10, 34, 37, 58)             param:[16, 10, 34, 37, 58]             output:torch.Size([16, 10, 34, 37, 58])
seed:001 seq:03 concat     input:(12, 44, 12, 26, 55)             param:(1,)                             output:torch.Size([12, 88, 12, 26, 55])
seed:001 seq:04 expand     input:(48, 9, 28, 20, 68)              param:[48, 9, 28, 20, 68]              output:torch.Size([48, 9, 28, 20, 68])
seed:001 seq:05 repeat     input:(16, 16, 162, 233)               param:(1, 1, 1, 1)                     output:torch.Size([16, 16, 162, 233])
seed:001 seq:06 expand     input:(25, 426, 19, 63)                param:[25, 426, 19, 63]                output:torch.Size([25, 426, 19, 63])
seed:001 seq:07 permute    input:(153, 153, 380)                  param:[2, 1, 0]                        output:torch.Size([380, 153, 153])
seed:001 seq:08 permute    input:(3091, 1445)                     param:[1, 0]                           output:torch.Size([1445, 3091])
seed:001 seq:09 mul        input:(142, 254, 388)                  param:()                               output:torch.Size([142, 254, 388])
[3.31   0.3372 0.2354 0.1373 0.594  2.326  0.7344 2.16  ]

相关文章:

  • jpom ruoyi 发布后端
  • TypeScript 项目,自身 package 是 A,它引用了 B package。项目编译时,选择依赖版本的机制是什么?
  • 计算机毕业设计 | SpringBoot图书管理系统(附源码)
  • Qt界面开发软件使用介绍
  • react自用小技巧(持续更新中)
  • 最近关于工作与学习的一点思考
  • 深入解析Spring Cloud Consul:让微服务间的通信和管理更简单
  • CSS简述(1)
  • 使用LLaMA-Factory微调大模型
  • java mybatis处理大数据量,开启和配置二级缓存,及注意事项,已解决
  • Java 18新特性深度解析:提升开发效率与性能的革新工具
  • 重生之 SpringBoot3 入门保姆级学习(16、函数式 Web 编程)
  • 【NOIP提高组】方格取数
  • 如何将静态TCP/IP路由添加到Windows路由表?这里提供方法
  • Java线程中sleep()和wait()有什么区别
  • 【刷算法】求1+2+3+...+n
  • 【跃迁之路】【699天】程序员高效学习方法论探索系列(实验阶段456-2019.1.19)...
  • Angular2开发踩坑系列-生产环境编译
  • CAP 一致性协议及应用解析
  • Codepen 每日精选(2018-3-25)
  • go append函数以及写入
  • Golang-长连接-状态推送
  • Java读取Properties文件的六种方法
  • PHP 使用 Swoole - TaskWorker 实现异步操作 Mysql
  • RxJS 实现摩斯密码(Morse) 【内附脑图】
  • SpiderData 2019年2月23日 DApp数据排行榜
  • Yeoman_Bower_Grunt
  • 闭包,sync使用细节
  • 关键词挖掘技术哪家强(一)基于node.js技术开发一个关键字查询工具
  • 理解在java “”i=i++;”所发生的事情
  • 人脸识别最新开发经验demo
  • 在Docker Swarm上部署Apache Storm:第1部分
  • 自定义函数
  • python最赚钱的4个方向,你最心动的是哪个?
  • 国内唯一,阿里云入选全球区块链云服务报告,领先AWS、Google ...
  • (2024,RWKV-5/6,RNN,矩阵值注意力状态,数据依赖线性插值,LoRA,多语言分词器)Eagle 和 Finch
  • (NO.00004)iOS实现打砖块游戏(十二):伸缩自如,我是如意金箍棒(上)!
  • (附源码)spring boot网络空间安全实验教学示范中心网站 毕业设计 111454
  • (黑客游戏)HackTheGame1.21 过关攻略
  • (四)c52学习之旅-流水LED灯
  • (四)TensorRT | 基于 GPU 端的 Python 推理
  • (转)shell中括号的特殊用法 linux if多条件判断
  • (转载)虚幻引擎3--【UnrealScript教程】章节一:20.location和rotation
  • . Flume面试题
  • .NET MAUI Sqlite数据库操作(二)异步初始化方法
  • .Net OpenCVSharp生成灰度图和二值图
  • .net 程序发生了一个不可捕获的异常
  • .Net 中的反射(动态创建类型实例) - Part.4(转自http://www.tracefact.net/CLR-and-Framework/Reflection-Part4.aspx)...
  • .Net(C#)自定义WinForm控件之小结篇
  • .NET/C# 项目如何优雅地设置条件编译符号?
  • .NetCore Flurl.Http 升级到4.0后 https 无法建立SSL连接
  • .NET基础篇——反射的奥妙
  • .Net下C#针对Excel开发控件汇总(ClosedXML,EPPlus,NPOI)
  • @RequestMapping 的作用是什么?
  • @staticmethod和@classmethod的作用与区别