diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -414,6 +414,12 @@
     result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
         loc, scope, adaptor.getValue(), adaptor.getOffset());
     break;
+  case gpu::ShuffleMode::IDX:
+    // In IDX mode the offset operand is the id of the lane to read from,
+    // which is what GroupNonUniformShuffle takes as its Id operand.
+    result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
+        loc, scope, adaptor.getValue(), adaptor.getOffset());
+    break;
   default:
     return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
   }
diff --git a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/shuffle.mlir
@@ -46,3 +46,31 @@
 }
 
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle], []>, #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @shuffle_idx()
+  gpu.func @shuffle_idx() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[16, 1, 1]>: vector<3xi32>>} {
+    %mask = arith.constant 8 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // In idx mode the second operand (%mask here) is the source lane id; it is
+    // expected to be forwarded unchanged as the shuffle's Id operand.
+    // CHECK: %[[MASK:.+]] = spirv.Constant 8 : i32
+    // CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
+    // CHECK: %{{.+}} = spirv.Constant true
+    // CHECK: %{{.+}} = spirv.GroupNonUniformShuffle <Subgroup> %[[VAL]], %[[MASK]] : f32, i32
+    %result, %valid = gpu.shuffle idx %val, %mask, %width : f32
+    gpu.return
+  }
+}
+
+}