Summary of the hunks below (descriptive only; patch tools skip everything
before the first "diff --git" line):

  * SPIRVArithmeticOps.td, SPIRVCastOps.td: operand/result constraints change
    from SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf to
    SPV_ScalarOrVectorOrCoopMatrixOf.
  * SPIRVBase.td: removes the SPV_ScalarOrVectorOrJointMatrixOf and
    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf constraint classes.
  * SPIRVOps.cpp: a pointer storage-class check additionally accepts
    UniformConstant and Generic.
  * Both joint-matrix-ops.mlir tests: removes the joint-matrix arithmetic and
    AccessChain cases.

NOTE(review): the SPIRVOps.cpp hunk leaves the diagnostic text ("must be
Workgroup or CrossWorkgroup") untouched although two more storage classes are
now accepted. Consider updating the message together with any expected-error
checks that match it.

NOTE(review): leading indentation and blank context lines were restored from
the hunk line counts; confirm against the tree before applying (or apply with
whitespace-insensitive matching).

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -27,12 +27,12 @@
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4106,14 +4106,6 @@
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPV_CoopMatrixOfType<[type]>]>;
 
-class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPV_JointMatrixOfType<[type]>]>;
-
-class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>;
-
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -23,11 +23,11 @@
              !listconcat(traits,
                          [NoSideEffect, SameOperandsAndResultShape])> {
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
+    SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
   );
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -3921,7 +3921,9 @@
   spirv::StorageClass storage =
       pointer.cast<spirv::PointerType>().getStorageClass();
   if (storage != spirv::StorageClass::Workgroup &&
-      storage != spirv::StorageClass::CrossWorkgroup)
+      storage != spirv::StorageClass::CrossWorkgroup &&
+      storage != spirv::StorageClass::UniformConstant &&
+      storage != spirv::StorageClass::Generic)
     return op->emitError("Pointer storage class must be Workgroup or "
                          "CrossWorkgroup but provided ")
            << stringifyStorageClass(storage);
diff --git a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
--- a/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
@@ -50,65 +50,6 @@
   spv.Return
 }
 
-// CHECK-LABEL: @joint_matrix_add
-spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_sub
-spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_sdiv
-spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_udiv
-spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fadd
-spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fsub
-spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fdiv
-spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// -----
-
-// CHECK-LABEL: @joint_matrix_access_chain
-spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
-  %0 = spv.Constant 0: i32
-  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-  spv.ReturnValue %1 : !spv.ptr<f32, Function>
-}
-
 // -----
 
 spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
diff --git a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
--- a/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/joint-matrix-ops.mlir
@@ -42,61 +42,4 @@
     %r = spv.JointMatrixMadINTEL <Subgroup> %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
     spv.Return
   }
-
-  // CHECK-LABEL: @joint_matrix_add
-  spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_sub
-  spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_sdiv
-  spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_udiv
-  spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fadd
-  spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fsub
-  spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fdiv
-  spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_access_chain
-  spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
-    %0 = spv.Constant 0: i32
-    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-    spv.ReturnValue %1 : !spv.ptr<f32, Function>
-  }
 }