Drop JointMatrix types from the operand/result type constraints of the SPIR-V
arithmetic binary ops and cast ops, remove the
SPV_ScalarOrVectorOrJointMatrixOf and
SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf constraint classes, and delete
the JointMatrix arithmetic/access-chain tests that exercised them.

Also accept UniformConstant and Generic pointers in the pointer storage-class
check in SPIRVOps.cpp, and extend its diagnostic so it names every accepted
storage class. The new message keeps the old "Pointer storage class must be
Workgroup or CrossWorkgroup" prefix so checks matching on that prefix still
match.

Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -27,12 +27,12 @@
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand1,
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$operand2
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<type>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOf<type>:$result
   );
   let assemblyFormat = "operands attr-dict `:` type($result)";
 }
Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4106,14 +4106,6 @@
     AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
                SPV_CoopMatrixOfType<[type]>]>;
 
-class SPV_ScalarOrVectorOrJointMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPV_JointMatrixOfType<[type]>]>;
-
-class SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
-               SPV_CoopMatrixOfType<[type]>, SPV_JointMatrixOfType<[type]> ]>;
-
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -23,11 +23,11 @@
              !listconcat(traits,
                          [NoSideEffect, SameOperandsAndResultShape])> {
   let arguments = (ins
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<operandType>:$operand
+    SPV_ScalarOrVectorOrCoopMatrixOf<operandType>:$operand
   );
 
   let results = (outs
-    SPV_ScalarOrVectorOrCoopMatrixOfOrJointMatrixOf<resultType>:$result
+    SPV_ScalarOrVectorOrCoopMatrixOf<resultType>:$result
   );
   let assemblyFormat = [{
     $operand attr-dict `:` type($operand) `to` type($result)
Index: mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -3921,7 +3921,11 @@
   spirv::StorageClass storage =
       pointer.cast<spirv::PointerType>().getStorageClass();
+  // Keep the diagnostic below in sync with the accepted storage classes.
   if (storage != spirv::StorageClass::Workgroup &&
-      storage != spirv::StorageClass::CrossWorkgroup)
-    return op->emitError("Pointer storage class must be Workgroup or "
-                         "CrossWorkgroup but provided ")
+      storage != spirv::StorageClass::CrossWorkgroup &&
+      storage != spirv::StorageClass::UniformConstant &&
+      storage != spirv::StorageClass::Generic)
+    return op->emitError("Pointer storage class must be Workgroup or "
+                         "CrossWorkgroup or UniformConstant or Generic but "
+                         "provided ")
            << stringifyStorageClass(storage);
Index: mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
+++ mlir/test/Dialect/SPIRV/IR/joint-matrix-ops.mlir
@@ -50,65 +50,6 @@
   spv.Return
 }
 
-// CHECK-LABEL: @joint_matrix_add
-spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_sub
-spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_sdiv
-spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_udiv
-spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fadd
-spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fsub
-spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// CHECK-LABEL: @joint_matrix_fdiv
-spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-  spv.Return
-}
-
-// -----
-
-// CHECK-LABEL: @joint_matrix_access_chain
-spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
-  %0 = spv.Constant 0: i32
-  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-  spv.ReturnValue %1 : !spv.ptr<f32, Function>
-}
-
 // -----
 
 spv.func @joint_matrix_muladd(%a : !spv.jointmatrix<16x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<16x8xi32, RowMajor, Subgroup>, %c : !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>) "None" {
Index: mlir/test/Target/SPIRV/joint-matrix-ops.mlir
===================================================================
--- mlir/test/Target/SPIRV/joint-matrix-ops.mlir
+++ mlir/test/Target/SPIRV/joint-matrix-ops.mlir
@@ -42,61 +42,4 @@
     %r = spv.JointMatrixMadINTEL %a, %b, %c : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, !spv.jointmatrix<16x8xi32, RowMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup>
     spv.Return
   }
-
-  // CHECK-LABEL: @joint_matrix_add
-  spv.func @joint_matrix_add(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.IAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.IAdd %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_sub
-  spv.func @joint_matrix_sub(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.ISub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.ISub %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_sdiv
-  spv.func @joint_matrix_sdiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.SDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.SDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_udiv
-  spv.func @joint_matrix_udiv(%a : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.UDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    %r = spv.UDiv %a, %b : !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fadd
-  spv.func @joint_matrix_fadd(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FAdd {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FAdd %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fsub
-  spv.func @joint_matrix_fsub(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FSub {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FSub %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_fdiv
-  spv.func @joint_matrix_fdiv(%a : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>) "None" {
-    // CHECK: {{%.*}} = spv.FDiv {{%.*}}, {{%.*}} : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    %r = spv.FDiv %a, %b : !spv.jointmatrix<8x16xf32, RowMajor, Subgroup>
-    spv.Return
-  }
-
-  // CHECK-LABEL: @joint_matrix_access_chain
-  spv.func @joint_matrix_access_chain(%a : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>) -> !spv.ptr<f32, Function> "None" {
-    %0 = spv.Constant 0: i32
-    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.jointmatrix<8x16xf32, RowMajor, Subgroup>, Function>, i32
-    spv.ReturnValue %1 : !spv.ptr<f32, Function>
-  }
 }