diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td @@ -262,7 +262,7 @@ // ----- -def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> { +def SPV_EXTAtomicFAddOp : SPV_ExtVendorOp<"AtomicFAdd", []> { let summary = "TBD"; let description = [{ @@ -279,7 +279,7 @@ 3) store the New Value back through Pointer. - The instruction’s result is the Original Value. + The instruction's result is the Original Value. Result Type must be a floating-point type scalar. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -15,7 +15,7 @@ // ----- -def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV", +def SPV_NVCooperativeMatrixLengthOp : SPV_NvVendorOp<"CooperativeMatrixLength", [NoSideEffect]> { let summary = "See extension SPV_NV_cooperative_matrix"; @@ -60,7 +60,7 @@ // ----- -def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { +def SPV_NVCooperativeMatrixLoadOp : SPV_NvVendorOp<"CooperativeMatrixLoad", []> { let summary = "See extension SPV_NV_cooperative_matrix"; let description = [{ @@ -136,7 +136,7 @@ // ----- -def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", +def SPV_NVCooperativeMatrixMulAddOp : SPV_NvVendorOp<"CooperativeMatrixMulAdd", [NoSideEffect, AllTypesMatch<["c", "result"]>]> { let summary = "See extension SPV_NV_cooperative_matrix"; @@ -210,7 +210,7 @@ // ----- -def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> { +def SPV_NVCooperativeMatrixStoreOp : SPV_NvVendorOp<"CooperativeMatrixStore", []> { let summary = "See extension SPV_NV_cooperative_matrix"; let 
description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td @@ -92,7 +92,7 @@ // ----- -def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { +def SPV_KHRSubgroupBallotOp : SPV_KhrVendorOp<"SubgroupBallot", []> { let summary = "See extension SPV_KHR_shader_ballot"; let description = [{ @@ -146,7 +146,7 @@ // ----- -def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> { +def SPV_INTELSubgroupBlockReadOp : SPV_IntelVendorOp<"SubgroupBlockRead", []> { let summary = "See extension SPV_INTEL_subgroups"; let description = [{ @@ -197,7 +197,7 @@ // ----- -def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> { +def SPV_INTELSubgroupBlockWriteOp : SPV_IntelVendorOp<"SubgroupBlockWrite", []> { let summary = "See extension SPV_INTEL_subgroups"; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td @@ -15,12 +15,12 @@ // ----- -def SPV_JointMatrixWorkItemLengthINTELOp : SPV_Op<"JointMatrixWorkItemLengthINTEL", +def SPV_INTELJointMatrixWorkItemLengthOp : SPV_IntelVendorOp<"JointMatrixWorkItemLength", [NoSideEffect]> { let summary = "See extension SPV_INTEL_joint_matrix"; let description = [{ - Return number of components owned by the current work-item in + Return number of components owned by the current work-item in a joint matrix. Result Type must be an 32-bit unsigned integer type scalar. 
@@ -60,7 +60,7 @@ // ----- -def SPV_JointMatrixLoadINTELOp : SPV_Op<"JointMatrixLoadINTEL", []> { +def SPV_INTELJointMatrixLoadOp : SPV_IntelVendorOp<"JointMatrixLoad", []> { let summary = "See extension SPV_INTEL_joint_matrix"; let description = [{ @@ -68,26 +68,26 @@ Result Type is the type of the loaded matrix. It must be OpTypeJointMatrixINTEL. - Pointer is the pointer to load through. It specifies start of memory region where + Pointer is the pointer to load through. It specifies start of memory region where elements of the matrix are stored and arranged according to Layout. - Stride is the number of elements in memory between beginnings of successive rows, + Stride is the number of elements in memory between beginnings of successive rows, columns (or words) in the result. It must be a scalar integer type. - Layout indicates how the values loaded from memory are arranged. It must be the + Layout indicates how the values loaded from memory are arranged. It must be the result of a constant instruction. - Scope is syncronization scope for operation on the matrix. It must be the result + Scope is synchronization scope for operation on the matrix. It must be the result of a constant instruction with scalar integer type. - If present, any Memory Operands must begin with a memory operand literal. If not + If present, any Memory Operands must begin with a memory operand literal. If not present, it is the same as specifying the memory operand None.
#### Example: ```mlir - %0 = spv.JointMatrixLoadINTEL %ptr, %stride - {memory_access = #spv.memory_access} : - (!spv.ptr, i32) -> + %0 = spv.JointMatrixLoadINTEL %ptr, %stride + {memory_access = #spv.memory_access} : + (!spv.ptr, i32) -> !spv.jointmatrix<8x16xi32, ColumnMajor, Subgroup> ``` }]; @@ -119,39 +119,39 @@ // ----- -def SPV_JointMatrixMadINTELOp : SPV_Op<"JointMatrixMadINTEL", +def SPV_INTELJointMatrixMadOp : SPV_IntelVendorOp<"JointMatrixMad", [NoSideEffect, AllTypesMatch<["c", "result"]>]> { let summary = "See extension SPV_INTEL_joint_matrix"; let description = [{ - Multiply matrix A by matrix B and add matrix C to the result - of the multiplication: A*B+C. Here A is a M x K matrix, B is + Multiply matrix A by matrix B and add matrix C to the result + of the multiplication: A*B+C. Here A is a M x K matrix, B is a K x N matrix and C is a M x N matrix. - Behavior is undefined if sizes of operands do not meet the - conditions above. All operands and the Result Type must be + Behavior is undefined if sizes of operands do not meet the + conditions above. All operands and the Result Type must be OpTypeJointMatrixINTEL. - A must be a OpTypeJointMatrixINTEL whose Component Type is a - signed numerical type, Row Count equals to M and Column Count + A must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equals to M and Column Count equals to K - B must be a OpTypeJointMatrixINTEL whose Component Type is a - signed numerical type, Row Count equals to K and Column Count + B must be an OpTypeJointMatrixINTEL whose Component Type is a + signed numerical type, Row Count equals to K and Column Count equals to N - C and Result Type must be a OpTypeJointMatrixINTEL with Row + C and Result Type must be an OpTypeJointMatrixINTEL with Row Count equals to M and Column Count equals to N - Scope is syncronization scope for operation on the matrix.
- It must be the result of a constant instruction with scalar + Scope is synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar integer type. #### Example: ```mlir - %r = spv.JointMatrixMadINTEL %a, %b, %c : - !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, - !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> + %r = spv.JointMatrixMadINTEL %a, %b, %c : + !spv.jointmatrix<8x32xi8, RowMajor, Subgroup>, + !spv.jointmatrix<32x8xi8, ColumnMajor, Subgroup> -> !spv.jointmatrix<8x8xi32, RowMajor, Subgroup> ``` @@ -182,38 +182,38 @@ // ----- -def SPV_JointMatrixStoreINTELOp : SPV_Op<"JointMatrixStoreINTEL", []> { +def SPV_INTELJointMatrixStoreOp : SPV_IntelVendorOp<"JointMatrixStore", []> { let summary = "See extension SPV_INTEL_joint_matrix"; let description = [{ Store a matrix through a pointer. - Pointer is the pointer to store through. It specifies - start of memory region where elements of the matrix must + Pointer is the pointer to store through. It specifies + start of memory region where elements of the matrix must be stored and arranged according to Layout. - Object is the matrix to store. It must be + Object is the matrix to store. It must be OpTypeJointMatrixINTEL. - Stride is the number of elements in memory between beginnings - of successive rows, columns (or words) of the Object. It must + Stride is the number of elements in memory between beginnings + of successive rows, columns (or words) of the Object. It must be a scalar integer type. - Layout indicates how the values stored to memory are arranged. + Layout indicates how the values stored to memory are arranged. It must be the result of a constant instruction. - Scope is syncronization scope for operation on the matrix. - It must be the result of a constant instruction with scalar + Scope is synchronization scope for operation on the matrix. + It must be the result of a constant instruction with scalar integer type.
- If present, any Memory Operands must begin with a memory operand - literal. If not present, it is the same as specifying the memory + If present, any Memory Operands must begin with a memory operand + literal. If not present, it is the same as specifying the memory operand None. #### Example: ```mlir - spv.JointMatrixStoreINTEL %ptr, %m, %stride - {memory_access = #spv.memory_access} : (!spv.ptr, + spv.JointMatrixStoreINTEL %ptr, %m, %stride + {memory_access = #spv.memory_access} : (!spv.ptr, !spv.jointmatrix<8x16xi32, RowMajor, Subgroup>, i32) ``` diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td @@ -18,7 +18,7 @@ // ----- -def SPV_AssumeTrueKHROp : SPV_Op<"AssumeTrueKHR", []> { +def SPV_KHRAssumeTrueOp : SPV_KhrVendorOp<"AssumeTrue", []> { let summary = "TBD"; let description = [{ diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1335,15 +1335,15 @@ // spv.AtomicFAddEXTOp //===----------------------------------------------------------------------===// -LogicalResult spirv::AtomicFAddEXTOp::verify() { +LogicalResult spirv::EXTAtomicFAddOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } -ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser, +ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } -void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) { +void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } @@ -2617,7 +2617,7 @@ // spv.SubgroupBlockReadINTEL //===----------------------------------------------------------------------===// -ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser, 
+ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; @@ -2640,11 +2640,11 @@ return success(); } -void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) { +void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) { printer << " " << ptr() << " : " << getType(); } -LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { +LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); @@ -2655,7 +2655,7 @@ // spv.SubgroupBlockWriteINTEL //===----------------------------------------------------------------------===// -ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser, +ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the storage class specification spirv::StorageClass storageClass; @@ -2679,11 +2679,11 @@ return success(); } -void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) { +void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { printer << " " << ptr() << ", " << value() << " : " << value().getType(); } -LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { +LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); @@ -3787,7 +3787,7 @@ // spv.CooperativeMatrixLoadNV //===----------------------------------------------------------------------===// -ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser, +ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); @@ -3809,7 +3809,7 @@ return success(); } -void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) { +void 
spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); // Print optional memory access attribute. if (auto memAccess = memory_access()) @@ -3836,7 +3836,7 @@ return success(); } -LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { +LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { return verifyPointerAndCoopMatrixType(*this, pointer().getType(), result().getType()); } @@ -3845,7 +3845,7 @@ // spv.CooperativeMatrixStoreNV //===----------------------------------------------------------------------===// -ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser, +ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); @@ -3867,7 +3867,7 @@ return success(); } -void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) { +void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { printer << " " << pointer() << ", " << object() << ", " << stride() << ", " << columnmajor(); // Print optional memory access attribute. 
@@ -3876,7 +3876,7 @@ printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); } -LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { +LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { return verifyPointerAndCoopMatrixType(*this, pointer().getType(), object().getType()); } @@ -3886,7 +3886,7 @@ //===----------------------------------------------------------------------===// static LogicalResult -verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { +verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { if (op.c().getType() != op.result().getType()) return op.emitOpError("result and third operand must have the same type"); auto typeA = op.a().getType().cast(); @@ -3907,7 +3907,7 @@ return success(); } -LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() { +LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() { return verifyCoopMatrixMulAdd(*this); } @@ -3934,7 +3934,7 @@ // spv.JointMatrixLoadINTEL //===----------------------------------------------------------------------===// -LogicalResult spirv::JointMatrixLoadINTELOp::verify() { +LogicalResult spirv::INTELJointMatrixLoadOp::verify() { return verifyPointerAndJointMatrixType(*this, pointer().getType(), result().getType()); } @@ -3943,7 +3943,7 @@ // spv.JointMatrixStoreINTEL //===----------------------------------------------------------------------===// -LogicalResult spirv::JointMatrixStoreINTELOp::verify() { +LogicalResult spirv::INTELJointMatrixStoreOp::verify() { return verifyPointerAndJointMatrixType(*this, pointer().getType(), object().getType()); } @@ -3952,7 +3952,7 @@ // spv.JointMatrixMadINTEL //===----------------------------------------------------------------------===// -static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) { +static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { if (op.c().getType() != op.result().getType()) return op.emitOpError("result and third operand must have the 
same type"); auto typeA = op.a().getType().cast(); @@ -3973,7 +3973,7 @@ return success(); } -LogicalResult spirv::JointMatrixMadINTELOp::verify() { +LogicalResult spirv::INTELJointMatrixMadOp::verify() { return verifyJointMatrixMad(*this); } diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -240,7 +240,7 @@ PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); return success(); } diff --git a/mlir/utils/spirv/define_inst.sh b/mlir/utils/spirv/define_inst.sh --- a/mlir/utils/spirv/define_inst.sh +++ b/mlir/utils/spirv/define_inst.sh @@ -23,13 +23,17 @@ baseclass=$2 case $baseclass in - Op | ArithmeticBinaryOp | ArithmeticUnaryOp | LogicalBinaryOp | LogicalUnaryOp | CastOp | ControlFlowOp | StructureOp | AtomicUpdateOp | AtomicUpdateWithValueOp) + Op | ArithmeticBinaryOp | ArithmeticUnaryOp \ + | LogicalBinaryOp | LogicalUnaryOp \ + | CastOp | ControlFlowOp | StructureOp \ + | AtomicUpdateOp | AtomicUpdateWithValueOp \ + | KhrVendorOp | ExtVendorOp | IntelVendorOp | NvVendorOp ) ;; *) echo "Usage : " $0 " ()*" echo " is the file name of MLIR SPIR-V op definitions spec" echo " must be one of " \ - "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp)" + "(Op|ArithmeticBinaryOp|ArithmeticUnaryOp|LogicalBinaryOp|LogicalUnaryOp|CastOp|ControlFlowOp|StructureOp|AtomicUpdateOp|KhrVendorOp|ExtVendorOp|IntelVendorOp|NvVendorOp)" exit 1; ;; esac diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -730,15 +730,19 @@ '{{\n let summary = {summary};\n\n let description = ' 
'[{{\n{description}}}];{availability}\n') else: - fmt_str = ('def SPV_{opname_src}Op : ' - 'SPV_{inst_category}<"{opname_src}"{category_args}[{traits}]> ' + fmt_str = ('def SPV_{vendor_name}{opname_src}Op : ' + 'SPV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> ' '{{\n let summary = {summary};\n\n let description = ' '[{{\n{description}}}];{availability}\n') + vendor_name = '' inst_category = existing_info.get('inst_category', 'Op') if inst_category == 'Op': fmt_str +='\n let arguments = (ins{args});\n\n'\ ' let results = (outs{results});\n' + elif inst_category.endswith('VendorOp'): + vendor_name = inst_category.split('VendorOp')[0].upper() + assert len(vendor_name) != 0, 'Invalid instruction category' fmt_str +='{extras}'\ '}}\n' @@ -746,6 +750,9 @@ opname_src = instruction['opname'] if opname.startswith('Op'): opname_src = opname_src[2:] + if len(vendor_name) > 0: + assert opname_src.endswith(vendor_name), "op name does not match the instruction category" + opname_src = opname_src[:-len(vendor_name)] category_args = existing_info.get('category_args', '') @@ -759,7 +766,7 @@ # Format summary. If the summary can fit in the same line, we print it out # as a "-quoted string; otherwise, wrap the lines using "[{...}]". - summary = summary.strip(); + summary = summary.strip() if len(summary) + len(' let summary = "";') <= 80: summary = '"{}"'.format(summary) else: @@ -815,6 +822,7 @@ opcode=instruction['opcode'], category_args=category_args, inst_category=inst_category, + vendor_name=vendor_name, traits=existing_info.get('traits', ''), summary=summary, description=description,