Skip to content

Commit f696f55

Browse files
committed
[CK][Examples] Fixing stride issues in ck examples 14/65/68/69 by workaround - Bypassing hostTensor validation
-Fixing args num in ck examples 68/69 Signed-off-by: Michal Kulikowski <[email protected]>
1 parent 5122637 commit f696f55

File tree

9 files changed

+30
-22
lines changed

9 files changed

+30
-22
lines changed

example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ using ::ck::Tensor;
2727
template <ck::index_t... Is>
2828
using S = ck::Sequence<Is...>;
2929

30-
using I8 = int8_t;
31-
using I32 = int32_t;
32-
using Row = ck::tensor_layout::gemm::RowMajor;
33-
using Col = ck::tensor_layout::gemm::ColumnMajor;
30+
using I8 = int8_t;
31+
using I32 = int32_t;
32+
using Row = ck::tensor_layout::gemm::RowMajor;
33+
using Col = ck::tensor_layout::gemm::ColumnMajor;
34+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
3435

3536
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
3637
using ActivationOp = PassThrough;
@@ -125,11 +126,11 @@ int main(int /* argc */, char* /* argv */[])
125126

126127
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
127128
{
128-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
129+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
129130
}
130131
else
131132
{
132-
return HostTensorDescriptor({row, col}, {1_uz, stride});
133+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
133134
}
134135
};
135136

example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ using S = ck::Sequence<Is...>;
3131
using F16 = ck::half_t;
3232
using F32 = float;
3333

34-
using Row = ck::tensor_layout::gemm::RowMajor;
35-
using Col = ck::tensor_layout::gemm::ColumnMajor;
34+
using Row = ck::tensor_layout::gemm::RowMajor;
35+
using Col = ck::tensor_layout::gemm::ColumnMajor;
36+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
3637

3738
using A0DataType = F16;
3839
using B0DataType = F16;
@@ -139,11 +140,11 @@ int main(int argc, char* argv[])
139140

140141
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
141142
{
142-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
143+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
143144
}
144145
else
145146
{
146-
return HostTensorDescriptor({row, col}, {1_uz, stride});
147+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
147148
}
148149
};
149150

example/65_gemm_multiply_multiply/run_gemm_multiply_multiply_wp_example.inc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
int run_gemm_example(int argc, char* argv[])
77
{
8+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
9+
810
bool do_verification = true;
911
int init_method = 1;
1012
bool time_kernel = false;
@@ -64,11 +66,11 @@ int run_gemm_example(int argc, char* argv[])
6466

6567
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
6668
{
67-
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz});
69+
return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
6870
}
6971
else
7072
{
71-
return ck::HostTensorDescriptor({row, col}, {1_uz, stride});
73+
return ck::HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
7274
}
7375
};
7476

example/68_gemm_add/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
8787
config.init_method = std::stoi(argv[2]);
8888
config.time_kernel = std::stoi(argv[3]);
8989
}
90-
else if(argc == 13)
90+
else if(argc == 11)
9191
{
9292
config.do_verification = std::stoi(argv[1]);
9393
config.init_method = std::stoi(argv[2]);

example/68_gemm_add/run_gemm_add_example_wmma.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
77
{
88
using namespace ck::literals;
9+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
910

1011
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
1112

1213
auto f_host_tensor_descriptor =
1314
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
1415
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
1516
{
16-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
17+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
1718
}
1819
else
1920
{
20-
return HostTensorDescriptor({row, col}, {1_uz, stride});
21+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
2122
}
2223
};
2324

example/68_gemm_add/run_gemm_add_example_xdl.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
77
{
88
using namespace ck::literals;
9+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
910

1011
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
1112

1213
auto f_host_tensor_descriptor =
1314
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
1415
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
1516
{
16-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
17+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
1718
}
1819
else
1920
{
20-
return HostTensorDescriptor({row, col}, {1_uz, stride});
21+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
2122
}
2223
};
2324

example/69_gemm_add_relu/common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
8787
config.init_method = std::stoi(argv[2]);
8888
config.time_kernel = std::stoi(argv[3]);
8989
}
90-
else if(argc == 13)
90+
else if(argc == 11)
9191
{
9292
config.do_verification = std::stoi(argv[1]);
9393
config.init_method = std::stoi(argv[2]);

example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
77
{
88
using namespace ck::literals;
9+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
910

1011
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
1112

1213
auto f_host_tensor_descriptor =
1314
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
1415
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
1516
{
16-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
17+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
1718
}
1819
else
1920
{
20-
return HostTensorDescriptor({row, col}, {1_uz, stride});
21+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
2122
}
2223
};
2324

example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
77
{
88
using namespace ck::literals;
9+
using Bypass = ck::tensor_layout::BypassLayoutVerification;
910

1011
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
1112

1213
auto f_host_tensor_descriptor =
1314
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
1415
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
1516
{
16-
return HostTensorDescriptor({row, col}, {stride, 1_uz});
17+
return HostTensorDescriptor({row, col}, {stride, 1_uz}, Bypass{});
1718
}
1819
else
1920
{
20-
return HostTensorDescriptor({row, col}, {1_uz, stride});
21+
return HostTensorDescriptor({row, col}, {1_uz, stride}, Bypass{});
2122
}
2223
};
2324

0 commit comments

Comments
 (0)