compiled quantized model
Paste P1225877385 | Author: jerryzh | Created: Mon Apr 29, 2024 12:35pm
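# NOTE: the module below is TorchInductor's generated output code for a
# torch.compile'd, int8-dynamically-quantized encoder model in bfloat16 on
# CUDA; the shapes (197 tokens x 768 channels, class-token cat +
# position-embedding add, encoder_layers_encoder_layer_0) match a ViT-B/16 at
# 224x224 with batch size 1. A minimal, hypothetical repro sketch for capturing
# output code like this — the quantization helper is an assumption (e.g. a
# torchao-style int8 dynamic quantization API); TORCH_LOGS and torch.compile
# are standard:
#
#   $ TORCH_LOGS="output_code" python repro.py
#
#   # repro.py
#   import torch, torchvision
#   model = torchvision.models.vit_b_16().cuda().to(torch.bfloat16).eval()
#   # apply_int8_dynamic_quant(model)  # hypothetical helper; swaps Linears to int8 dynamic quant
#   compiled = torch.compile(model, mode="max-autotune")
#   compiled(torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16))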
from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.codecache import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
alloc_from_pool = torch.ops.inductor._alloc_from_pool
reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
async_compile = AsyncCompile()
# kernel path: /tmp/torchinductor_jerryzh/3m/c3mdxfx77gic7e55bygnquq4oboho4wkm2knnxiy3fbmx6ka5zxd.py
# Source Nodes: [cat_1, input_1, x_5], Original ATen: [aten.add, aten.cat, aten.native_layer_norm]
# cat_1 => cat
# input_1 => add
# x_5 => convert_element_type, var_mean
triton_red_fused_add_cat_native_layer_norm_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.reduction(
    size_hints=[2048, 128],
    reduction_hint=ReductionHint.OUTER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*fp32', 5: '*fp32', 6: '*fp32', 7: 'i32', 8: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 8), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(8,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_cat_native_layer_norm_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    xnumel = 1182
    rnumel = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x1 = (xindex // 6)
    x0 = xindex % 6
    x3 = xindex
    tmp21_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp21_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    tmp21_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r2 = rindex
        tmp17 = tl.load(in_ptr3 + (r2 + (128*x3)), rmask & xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp0 = x1
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 1, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (r2 + (128*x0)), rmask & tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
        tmp7 = tl.where(tmp4, tmp5, tmp6)
        tmp8 = tmp0 >= tmp3
        tmp9 = tl.full([1, 1], 197, tl.int64)
        tmp10 = tmp0 < tmp9
        tmp11 = tl.load(in_ptr1 + ((196*r2) + (25088*x0) + (((-1) + x1) % 196)), rmask & tmp8 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (r2 + (128*x0)), rmask & tmp8 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp13 = tmp11 + tmp12
        tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
        tmp15 = tl.where(tmp8, tmp13, tmp14)
        tmp16 = tl.where(tmp4, tmp7, tmp15)
        tmp18 = tmp16 + tmp17
        tmp19 = tmp18.to(tl.float32)
        tmp20 = tl.broadcast_to(tmp19, [XBLOCK, RBLOCK])
        tmp21_mean_next, tmp21_m2_next, tmp21_weight_next = triton_helpers.welford_reduce(
            tmp20, tmp21_mean, tmp21_m2, tmp21_weight, roffset == 0
        )
        tmp21_mean = tl.where(rmask & xmask, tmp21_mean_next, tmp21_mean)
        tmp21_m2 = tl.where(rmask & xmask, tmp21_m2_next, tmp21_m2)
        tmp21_weight = tl.where(rmask & xmask, tmp21_weight_next, tmp21_weight)
    tmp21_tmp, tmp22_tmp, tmp23_tmp = triton_helpers.welford(
        tmp21_mean, tmp21_m2, tmp21_weight, 1
    )
    tmp21 = tmp21_tmp[:, None]
    tmp22 = tmp22_tmp[:, None]
    tmp23 = tmp23_tmp[:, None]
    tl.store(out_ptr0 + (x3), tmp21, xmask)
    tl.store(out_ptr1 + (x3), tmp22, xmask)
    tl.store(out_ptr2 + (x3), tmp23, xmask)
''', device_str='cuda')
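# The reduction kernel above computes the LayerNorm statistics with Welford's
# online algorithm: each program accumulates partial (mean, m2, weight) over a
# 128-element chunk of a 768-wide row (6 chunks per token, xnumel = 197 * 6),
# and the persistent-reduction kernel below merges the six partials per token.
# A minimal eager-mode sketch of the combine step (illustrative only; this
# helper is not called anywhere in this file):
def _sketch_welford_combine(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    # Chan et al.'s parallel update: merge two partial (mean, M2, count)
    # accumulators, as triton_helpers.welford_reduce/welford do in-kernel.
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * (n_b / n)
    m2 = m2_a + m2_b + delta * delta * (n_a * n_b / n)
    return mean, m2, n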
import triton
import triton.language as tl
from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# kernel path: /tmp/torchinductor_jerryzh/rp/crprwyxhxbgfxjioi6d64h5rouzl2mdpzshcznykn37q2luuupzp.py
# Source Nodes: [cat_1, input_1, x_5], Original ATen: [aten.add, aten.cat, aten.native_layer_norm]
# cat_1 => cat
# input_1 => add
# x_5 => convert_element_type, var_mean
triton_per_fused_add_cat_native_layer_norm_1 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.persistent_reduction(
    size_hints=[256, 8],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: 'i32', 6: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_cat_native_layer_norm_1', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 197
    rnumel = 6
    RBLOCK: tl.constexpr = 8
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (6*x0)), rmask & xmask, other=0.0)
    tmp1 = tl.load(in_ptr1 + (r1 + (6*x0)), rmask & xmask, other=0.0)
    tmp2 = tl.load(in_ptr2 + (r1 + (6*x0)), rmask & xmask, other=0.0)
    tmp3 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp4 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    tmp5 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
    tmp7 = tl.where(rmask & xmask, tmp3, 0)
    tmp8 = tl.where(rmask & xmask, tmp4, 0)
    tmp9 = tl.where(rmask & xmask, tmp5, 0)
    tmp10, tmp11, tmp12 = triton_helpers.welford(tmp7, tmp8, tmp9, 1)
    tmp13 = tmp10[:, None]
    tmp14 = tmp11[:, None]
    tmp15 = tmp12[:, None]
    tl.store(out_ptr0 + (x0), tmp13, xmask)
    tl.store(out_ptr1 + (x0), tmp14, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_jerryzh/4g/c4gelhmxwybhe35chfwqwajj7xo2bonxrt4coqtti75qtsxpc3dd.py
# Source Nodes: [cat_1, input_1, x_5], Original ATen: [aten.add, aten.cat, aten.native_layer_norm]
# cat_1 => cat
# input_1 => add
# x_5 => add_1, add_2, convert_element_type, convert_element_type_1, mul, mul_1, rsqrt, sub, var_mean
triton_poi_fused_add_cat_native_layer_norm_2 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[262144],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: '*bf16', 4: '*fp32', 5: '*fp32', 6: '*bf16', 7: '*bf16', 8: '*bf16', 9: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(9,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_cat_native_layer_norm_2', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr1, xnumel, XBLOCK : tl.constexpr):
    xnumel = 151296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x1 = (xindex // 768)
    x0 = xindex % 768
    x2 = xindex
    tmp17 = tl.load(in_ptr3 + (x2), xmask).to(tl.float32)
    tmp20 = tl.load(in_ptr4 + (x1), xmask, eviction_policy='evict_last')
    tmp22 = tl.load(in_ptr5 + (x1), xmask, eviction_policy='evict_last')
    tmp29 = tl.load(in_ptr6 + (x0), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp32 = tl.load(in_ptr7 + (x0), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x1
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 1, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.load(in_ptr0 + (x0), tmp4 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp6 = tl.full(tmp5.shape, 0.0, tmp5.dtype)
    tmp7 = tl.where(tmp4, tmp5, tmp6)
    tmp8 = tmp0 >= tmp3
    tmp9 = tl.full([1], 197, tl.int64)
    tmp10 = tmp0 < tmp9
    tmp11 = tl.load(in_ptr1 + ((196*x0) + (((-1) + x1) % 196)), tmp8 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (x0), tmp8 & xmask, eviction_policy='evict_last', other=0.0).to(tl.float32)
    tmp13 = tmp11 + tmp12
    tmp14 = tl.full(tmp13.shape, 0.0, tmp13.dtype)
    tmp15 = tl.where(tmp8, tmp13, tmp14)
    tmp16 = tl.where(tmp4, tmp7, tmp15)
    tmp18 = tmp16 + tmp17
    tmp19 = tmp18.to(tl.float32)
    tmp21 = tmp19 - tmp20
    tmp23 = 768.0
    tmp24 = tmp22 / tmp23
    tmp25 = 1e-06
    tmp26 = tmp24 + tmp25
    tmp27 = libdevice.rsqrt(tmp26)
    tmp28 = tmp21 * tmp27
    tmp30 = tmp29.to(tl.float32)
    tmp31 = tmp28 * tmp30
    tmp33 = tmp32.to(tl.float32)
    tmp34 = tmp31 + tmp33
    tmp35 = tmp34.to(tl.float32)
    tl.store(out_ptr1 + (x2), tmp35, xmask)
''', device_str='cuda')
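# The pointwise kernel above finishes the fused cat + position-embedding add +
# LayerNorm: it recomputes the concatenated, position-embedded input, normalizes
# it with the mean and the m2 accumulator (divided by 768 to get the variance)
# produced by the two reduction kernels, and applies the affine weight/bias,
# computing in fp32 and storing bf16. Eager-mode sketch of the normalization
# math (illustrative only; eps matches the kernel's 1e-06 constant):
def _sketch_layer_norm_bf16(x, weight, bias, eps=1e-06):
    x32 = x.to(torch.float32)
    mean = x32.mean(dim=-1, keepdim=True)
    var = x32.var(dim=-1, unbiased=False, keepdim=True)  # tmp22 / 768.0 in-kernel
    y = (x32 - mean) * torch.rsqrt(var + eps)
    return (y * weight.to(torch.float32) + bias.to(torch.float32)).to(torch.bfloat16)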
# kernel path: /tmp/torchinductor_jerryzh/f6/cf6bv6b4vfuvbectvwbrf5n7kguat2tmat24wzdu32injuk3czgi.py
# Source Nodes: [l__self___encoder_layers_encoder_layer_0_self_attention], Original ATen: [aten._scaled_dot_product_flash_attention]
# l__self___encoder_layers_encoder_layer_0_self_attention => _scaled_dot_product_flash_attention
triton_poi_fused__scaled_dot_product_flash_attention_3 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[262144],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(3,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention_3', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 151296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 768
    x1 = (xindex // 768)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (x0 + (2304*x1)), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x2), tmp2, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_jerryzh/ni/cnibta3iuvpy5fbssjg635n7b6h3fa5rwyhvzhm3ouugyisfup6w.py
# Source Nodes: [l__self___encoder_layers_encoder_layer_0_self_attention], Original ATen: [aten._scaled_dot_product_flash_attention]
# l__self___encoder_layers_encoder_layer_0_self_attention => _scaled_dot_product_flash_attention
triton_poi_fused__scaled_dot_product_flash_attention_4 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[262144],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(3,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention_4', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 151296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 768
    x1 = (xindex // 768)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (768 + x0 + (2304*x1)), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (768 + x0), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x2), tmp2, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_jerryzh/vm/cvm3ucqwmyyakgkf2cfvaeylq3ktp3i7s5emlef7zyz44bzlk4jt.py
# Source Nodes: [l__self___encoder_layers_encoder_layer_0_self_attention], Original ATen: [aten._scaled_dot_product_flash_attention]
# l__self___encoder_layers_encoder_layer_0_self_attention => _scaled_dot_product_flash_attention
triton_poi_fused__scaled_dot_product_flash_attention_5 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.pointwise(
    size_hints=[262144],
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*bf16', 2: '*bf16', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(3,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__scaled_dot_product_flash_attention_5', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 151296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 768
    x1 = (xindex // 768)
    x2 = xindex
    tmp0 = tl.load(in_ptr0 + (1536 + x0 + (2304*x1)), xmask).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (1536 + x0), xmask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x2), tmp2, xmask)
''', device_str='cuda')
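# Kernels 3-5 above differ only in their 0 / 768 / 1536 offsets: they slice the
# fused QKV projection output (row width 2304 = 3 * 768) into contiguous Q, K
# and V buffers while adding the matching bias slice, producing the inputs for
# aten._scaled_dot_product_flash_attention. Eager-mode sketch of the same
# slicing (illustrative only):
def _sketch_split_qkv(qkv, qkv_bias):
    # qkv: [tokens, 2304] projection output, qkv_bias: [2304]
    q, k, v = (qkv + qkv_bias).chunk(3, dim=-1)  # three [tokens, 768] tensors
    return q, k, v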
# kernel path: /tmp/torchinductor_jerryzh/fv/cfv2vtxyxm4ieahpqcdzbdf2o2u6h4ju6comfyd5t3tuvdlqmnv5.py
# Source Nodes: [l__self___encoder_layers_encoder_layer_0_self_attention], Original ATen: [aten._to_copy, aten.abs, aten.amax, aten.clamp, aten.div, aten.round]
# l__self___encoder_layers_encoder_layer_0_self_attention => abs_1, amax, clamp_max, clamp_min, clamp_min_1, convert_element_type_5, convert_element_type_6, convert_element_type_7, convert_element_type_8, convert_element_type_9, div, div_1, round_1
triton_per_fused__to_copy_abs_amax_clamp_div_round_6 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.persistent_reduction(
    size_hints=[256, 1024],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {0: '*bf16', 1: '*i8', 2: '*bf16', 3: 'i32', 4: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 4), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(4,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_abs_amax_clamp_div_round_6', 'mutated_arg_names': [], 'no_x_dim': True, 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'}
)
@triton.jit
def triton_(in_ptr0, out_ptr1, out_ptr2, xnumel, rnumel):
    xnumel = 197
    XBLOCK: tl.constexpr = 1
    rnumel = 768
    RBLOCK: tl.constexpr = 1024
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[:]
    roffset = 0
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (768*x0)), rmask & xmask, other=0.0).to(tl.float32)
    tmp1 = tl_math.abs(tmp0)
    tmp2 = tl.broadcast_to(tmp1, [RBLOCK])
    tmp4 = tl.where(rmask & xmask, tmp2, float("-inf"))
    tmp5 = triton_helpers.promote_to_tensor(triton_helpers.max2(tmp4, 0))
    tmp6 = tmp5.to(tl.float32)
    tmp7 = 1e-05
    tmp8 = triton_helpers.maximum(tmp6, tmp7)
    tmp9 = tmp8.to(tl.float32)
    tmp10 = 127.0
    tmp11 = tmp9 / tmp10
    tmp12 = tmp0 / tmp11
    tmp13 = libdevice.nearbyint(tmp12)
    tmp14 = tmp13.to(tl.float32)
    tmp15 = -127.0
    tmp16 = triton_helpers.maximum(tmp14, tmp15)
    tmp17 = triton_helpers.minimum(tmp16, tmp10)
    tmp18 = tmp17.to(tl.float32)
    tmp19 = tmp18.to(tl.int8)
    tl.store(out_ptr1 + (r1 + (768*x0)), tmp19, rmask & xmask)
    tl.store(out_ptr2 + (x0), tmp11, xmask)
''', device_str='cuda')
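# The persistent reduction above is the dynamic activation quantization feeding
# the int8 linear: per token row it takes amax(|x|), floors it at 1e-05, derives
# a symmetric scale amax / 127, rounds x / scale to the nearest integer, clamps
# to [-127, 127], and stores the int8 values plus the per-row scale. Eager-mode
# sketch of the same math (illustrative only):
def _sketch_per_row_int8_quant(x):
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
    scale = torch.clamp(amax, min=1e-05) / 127.0
    q = torch.clamp(torch.round(x.to(torch.float32) / scale), -127.0, 127.0)
    return q.to(torch.int8), scale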
# kernel path: /tmp/torchinductor_jerryzh/tb/ctbetqksq5wdquspt6aqeaeczbvkly47okhymmcfwjlfwji4avos.py
# Source Nodes: [l__self___encoder_layers_encoder_layer_0_self_attention], Original ATen: [aten._to_copy, aten.clamp, aten.div, aten.mul, aten.round, aten.view]
# l__self___encoder_layers_encoder_layer_0_self_attention => clamp_max, clamp_min, clamp_min_1, convert_element_type_5, convert_element_type_6, convert_element_type_7, convert_element_type_8, convert_element_type_9, div, div_1, fused_int_mm_mul_36, round_1, view_11
triton_tem_fused__to_copy_clamp_div_mul_round_view_7 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor
@triton_heuristics.template(
    num_stages=5,
    num_warps=8,
    triton_meta={'signature': {0: '*i8', 1: '*i8', 2: '*bf16', 3: '*bf16'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
    inductor_meta={'kernel_name': 'triton_tem_fused__to_copy_clamp_div_mul_round_view_7', 'backend_hash': 'd389db67f581f05462832d65867d4147e9aefba5e2e1d883bbadd9e1860863a6'},
)
@triton.jit
def triton_(arg_A, arg_B, in_ptr2, out_ptr0):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.int32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 32
    BLOCK_K : tl.constexpr = 32
    A = arg_A
    B = arg_B
    M = 197
    N = 768
    K = 768
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 768
    stride_ak = 1
    stride_bk = 1
    stride_bn = 768
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)
    # inductor generates a suffix
    xindex = idx_n + (768*idx_m)
    tmp0 = tl.load(in_ptr2 + (tl.broadcast_to(idx_m, mask.shape)), mask, eviction_policy='evict_last').to(tl.float32)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc * tmp0, mask)
''', device_str='cuda')
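# The template kernel above is Inductor's matmul template specialized to int8
# inputs with an int32 accumulator (ACC_TYPE = tl.int32, ALLOW_TF32 = False);
# the generated suffix fuses the dequantization multiply by the per-row
# activation scale into the epilogue. A rough eager-mode equivalent using
# torch._int_mm, PyTorch's private int8 matmul (a sketch: weight-scale handling
# for the full quantized linear is omitted):
def _sketch_int8_mm_dequant(a_int8, b_int8, row_scale):
    # a_int8: [M, K] int8, b_int8: [K, N] int8, row_scale: [M, 1] float
    acc = torch._int_mm(a_int8, b_int8)        # int32 accumulate, like acc above
    return acc.to(torch.float32) * row_scale   # the template's fused suffix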
import torch._inductor.kernel.mm_common
meta0 = {'GROUP_M': 8, 'EVEN_K': True, 'ALLOW_TF32': False, 'ACC_TYPE': 'tl.int32', 'B_PROLOGUE_CAST_TYPE': None, 'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32}
# kernel path: /tmp/torchinductor_jerryzh/cr/ccr74bfho3ok5o2tyb5aq3refnm7mumphtnf7anixsxqq2ylylnx.py
# Source Nodes: [cat_1, input_1, l__self___encoder_layers_encoder_layer_0_mlp_0, x_7, x_8, y], Original ATen: [aten._to_copy, aten.abs, aten.add, aten.amax, aten.cat, aten.clamp, aten.clone, aten.div, aten.native_layer_norm, aten.round]
# cat_1 => cat
# input_1 => add
# l__self___encoder_layers_encoder_layer_0_mlp_0 => abs_2, amax_1, clamp_max_1, clamp_min_2, clamp_min_3, convert_element_type_12, convert_element_type_13, convert_element_type_14, convert_element_type_15, convert_element_type_16, div_2, div_3, round_2
# x_7 => clone