diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -399,6 +399,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>; def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>; +def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -457,7 +458,7 @@ SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone, SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix, - SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, + SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage, SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins, @@ -1413,6 +1414,12 @@ ]; } +def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> { + list availability = [ + Extension<[SPV_INTEL_bfloat16_conversion]> + ]; +} + def SPIRV_CapabilityAttr : SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [ SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16, @@ -1504,7 +1511,7 @@ SPIRV_C_UniformTexelBufferArrayNonUniformIndexing, SPIRV_C_StorageTexelBufferArrayNonUniformIndexing, SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, - SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL 
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -4079,6 +4086,7 @@ def SPIRV_Void : TypeAlias; def SPIRV_Bool : TypeAlias; def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>; +def SPIRV_Int16 : TypeAlias; def SPIRV_Int32 : TypeAlias; def SPIRV_Float32 : TypeAlias; def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; @@ -4407,6 +4415,9 @@ def SPIRV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>; def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>; +def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>; +def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>; + def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued, @@ -4492,7 +4503,8 @@ SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL, SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL, - SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL + SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL, + SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td @@ -0,0 +1,130 @@ +//===- SPIRVIntelExtOps.td - Intel SPIR-V extensions -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the op definition spec of Intel-specific SPIR-V extensions. +// These extensions are not part of the Khronos specification but are publicly +// available at https://github.com/intel/llvm. +// Supported extensions: +// * SPV_INTEL_bfloat16_conversion +//===----------------------------------------------------------------------===// + + +#ifndef MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS +#define MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS + +// ----- + +def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> { + let summary = "See extension SPV_INTEL_bfloat16_conversion"; + + let description = [{ + Convert value numerically from 32-bit floating point to bfloat16, + which is represented as a 16-bit unsigned integer. + + Result Type must be a scalar or vector of integer type. + The component width must be 16 bits. Bit pattern in the Result represents a bfloat16 value. + + Float Value must be a scalar or vector of floating-point type. + It must have the same number of components as Result Type. The component width must be 32 bits. + + Results are computed per component.
+ + ``` + float32-scalar-vector-type ::= f32 | `vector<` integer-literal `x` f32 `>` + integer16-scalar-vector-type ::= i16 | `vector<` integer-literal `x` i16 `>` + convert-f-to-bf16-op ::= ssa-id `=` `spirv.INTEL.ConvertFToBF16` ssa-use + `:` float32-scalar-vector-type `to` integer16-scalar-vector-type + ``` + + #### Example: + + ```mlir + %2 = spirv.INTEL.ConvertFToBF16 %0 : f32 to i16 + %3 = spirv.INTEL.ConvertFToBF16 %1 : vector<4xf32> to vector<4xi16> + + ``` + }]; + + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_bfloat16_conversion]>, + Capability<[SPIRV_C_Bfloat16ConversionINTEL]> + ]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + + let hasVerifier = 1; +} + +// ----- + +def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { + let summary = "See extension SPV_INTEL_bfloat16_conversion"; + + let description = [{ + Interpret a 16-bit integer as bfloat16 and convert the value numerically to 32-bit floating point type. + + Result Type must be a scalar or vector of floating-point type. The component width must be 32 bits. + + Bfloat16 Value must be a scalar or vector of integer type, which is interpreted as a bfloat16 type. + The type must have the same number of components as the Result Type. The component width must be 16 bits. + + Results are computed per component.
+ + ``` + float32-scalar-vector-type ::= f32 | `vector<` integer-literal `x` f32 `>` + integer16-scalar-vector-type ::= i16 | `vector<` integer-literal `x` i16 `>` + convert-bf16-to-f-op ::= ssa-id `=` `spirv.INTEL.ConvertBF16ToF` ssa-use + `:` integer16-scalar-vector-type `to` float32-scalar-vector-type + ``` + + #### Example: + + ```mlir + %2 = spirv.INTEL.ConvertBF16ToF %0 : i16 to f32 + %3 = spirv.INTEL.ConvertBF16ToF %1 : vector<4xi16> to vector<4xf32> + + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_bfloat16_conversion]>, + Capability<[SPIRV_C_Bfloat16ConversionINTEL]> + ]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + let hasVerifier = 1; +} + + +// ----- + +#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td @@ -31,6 +31,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td" +include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2201,6 +2201,38 @@ /*skipBitWidthCheck=*/true); } +//===----------------------------------------------------------------------===// +// spirv.INTEL.ConvertBF16ToF and spirv.INTEL.ConvertFToBF16 +//===----------------------------------------------------------------------===// + +/// Verifies that the operand and result of a bfloat16 conversion op are either +/// both scalars or both vectors with the same number of elements. ODS only +/// constrains each type on its own, so their relationship is checked here. +static LogicalResult verifyBF16ConversionShape(Operation *op) { + Type operandType = op->getOperand(0).getType(); + Type resultType = op->getResult(0).getType(); + auto operandVectorType = operandType.dyn_cast<VectorType>(); + auto resultVectorType = resultType.dyn_cast<VectorType>(); + if (!operandVectorType && !resultVectorType) + return success(); + // Reject scalar/vector mixes before comparing element counts. + if (!operandVectorType || !resultVectorType) + return op->emitOpError( + "operand and result must both be scalars or both be vectors"); + if (operandVectorType.getNumElements() != resultVectorType.getNumElements()) + return op->emitOpError( + "operand and result must have same number of elements"); + return success(); +} + +LogicalResult spirv::INTELConvertBF16ToFOp::verify() { + return verifyBF16ConversionShape(getOperation()); +} + +LogicalResult spirv::INTELConvertFToBF16Op::verify() { + return verifyBF16ConversionShape(getOperation()); +} + //===----------------------------------------------------------------------===// // spirv.EntryPoint //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s + +//===----------------------------------------------------------------------===// +// spirv.INTEL.ConvertFToBF16 +//===----------------------------------------------------------------------===// + +spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" { + // expected-error @+1 {{operand #0 must be}} + %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16 + spirv.Return +} + +// ----- + +spirv.func @f32_to_bf16_vec_unsupported(%arg0 : vector<2xf32>) "None" { + // expected-error @+1 {{operand and result must have same number of elements}} + %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<4xi16> + spirv.Return +} + +// ----- + +spirv.func @f32_to_bf16_vec_to_scalar(%arg0 : vector<2xf32>) "None" { + // expected-error @+1 {{operand and result must both be scalars or both be vectors}} + %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to i16 + spirv.Return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.INTEL.ConvertBF16ToF +//===----------------------------------------------------------------------===// + +spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { + // expected-error @+1 {{operand and result must have same number of elements}} + %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32> + spirv.Return +} + +// ----- + +spirv.func @bf16_to_f32_scalar_to_vec(%arg0 : i16) "None" { + // expected-error @+1 {{operand and result must both be scalars or both be vectors}} + %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to vector<2xf32> + spirv.Return +} diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK-LABEL: @f32_to_bf16 + spirv.func @f32_to_bf16(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16 + %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16 + spirv.Return + } + + // CHECK-LABEL: @f32_to_bf16_vec + spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16> + %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16> + spirv.Return + } + + // CHECK-LABEL: @bf16_to_f32 + spirv.func @bf16_to_f32(%arg0 : i16) "None" { + // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32 + %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32 + spirv.Return + } + + // CHECK-LABEL: @bf16_to_f32_vec + spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32> + %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32> + spirv.Return + } +}