diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3252,6 +3252,8 @@
 def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
 def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
 def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
+def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
+def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@
       SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
       SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
       SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
-      SPV_OC_OpCooperativeMatrixLengthNV
+      SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
+      SPV_OC_OpSubgroupBlockWriteINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -88,7 +88,6 @@
   let assemblyFormat = [{
     $execution_scope operands attr-dict `:` type($value) `,` type($localid)
   }];
-
 }
 
 // -----
@@ -147,4 +146,104 @@
 
 // -----
 
+def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Reads one or more components of Result data for each invocation in the
+    subgroup from the specified Ptr as a block operation.
+
+    The data is read strided, so the first value read is:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value read is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    Result Type may be a scalar or vector type, and its component type must be
+    equal to the type pointed to by Ptr.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
+                                storage-class ssa-use `:` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr
+  );
+
+  let results = (outs
+    SPV_Type:$value
+  );
+}
+
+// -----
+
+def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Writes one or more components of Data for each invocation in the subgroup
+    from the specified Ptr as a block operation.
+
+    The data is written strided, so the first value is written to:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value written is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    The component type of Data must be equal to the type pointed to by Ptr.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-write-INTEL-op ::= `spv.SubgroupBlockWriteINTEL`
+                      storage-class ssa-use `,` ssa-use `:` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    SPV_Type:$value
+  );
+
+  let results = (outs);
+}
+
+// -----
+
 #endif // SPIRV_GROUP_OPS
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -468,6 +468,22 @@
   return success();
 }
 
+/// Verifies that the type of `val` (or its element type when `val` is a
+/// vector) equals the pointee type of `ptr`. Emits an op error on `op` and
+/// returns failure otherwise. `ptr` must have a spirv::PointerType.
+template <typename BlockReadWriteOpTy>
+static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
+                                                        Value ptr, Value val) {
+  Type valType = val.getType();
+  if (auto valVecTy = valType.dyn_cast<VectorType>())
+    valType = valVecTy.getElementType();
+
+  auto ptrType = ptr.getType().cast<spirv::PointerType>();
+  if (valType != ptrType.getPointeeType())
+    return op.emitOpError("mismatch in value type and pointer type");
+  return success();
+}
+
 static ParseResult parseVariableDecorations(OpAsmParser &parser,
                                             OperationState &state) {
   auto builtInName = llvm::convertToSnakeFromCamelCase(
@@ -2025,6 +2041,98 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+/// Parses `storage-class ssa-use : type`, where `type` is the result type.
+/// The pointer operand's type is not spelled out in the assembly; it is
+/// rebuilt from the storage class and the result's scalar component type.
+static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  spirv::StorageClass storageClass;
+  OpAsmParser::OperandType ptrInfo;
+  Type resultType;
+  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
+      parser.parseColon() || parser.parseType(resultType))
+    return failure();
+
+  // The pointee is the scalar component type, even for vector results.
+  Type pointeeType = resultType;
+  if (auto vecTy = resultType.dyn_cast<VectorType>())
+    pointeeType = vecTy.getElementType();
+  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
+
+  if (parser.resolveOperand(ptrInfo, ptrType, state.operands))
+    return failure();
+
+  state.addTypes(resultType);
+  return success();
+}
+
+/// Prints the op in the form the parser accepts. The storage class is taken
+/// from the pointer operand's type so that the printed text can be re-parsed.
+static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
+                  OpAsmPrinter &printer) {
+  auto ptrType = blockReadOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockReadOp.ptr() << " : " << blockReadOp.getType();
+}
+
+/// Checks that the result's component type matches the pointee type.
+static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
+                                            blockReadOp.value());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+/// Parses `storage-class ssa-use, ssa-use : type`, where `type` is the type
+/// of the stored value. The pointer operand's type is rebuilt from the storage
+/// class and the value's scalar component type.
+static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type valueType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
+      parser.parseType(valueType))
+    return failure();
+
+  // The pointee is the scalar component type, even for vector values.
+  Type pointeeType = valueType;
+  if (auto vecTy = valueType.dyn_cast<VectorType>())
+    pointeeType = vecTy.getElementType();
+  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);
+
+  if (parser.resolveOperands(operandInfo, {ptrType, valueType}, loc,
+                             state.operands))
+    return failure();
+  return success();
+}
+
+/// Prints the op in the form the parser accepts. The storage class is taken
+/// from the pointer operand's type so that the printed text can be re-parsed.
+static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
+                  OpAsmPrinter &printer) {
+  auto ptrType = blockWriteOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockWriteOp.ptr() << ", " << blockWriteOp.value() << " : "
+          << blockWriteOp.value().getType();
+}
+
+/// Checks that the stored value's component type matches the pointee type.
+static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockWriteOp, blockWriteOp.ptr(),
+                                            blockWriteOp.value());
+}
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformElectOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -19,4 +19,28 @@
     %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
     spv.ReturnValue %0: f32
   }
+  // CHECK-LABEL: @subgroup_block_read_intel
+  spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    spv.ReturnValue %0: i32
+  }
+  // CHECK-LABEL: @subgroup_block_read_intel_vector
+  spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+    spv.ReturnValue %0: vector<3xi32>
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel
+  spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    spv.Return
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel_vector
+  spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+    spv.Return
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
--- a/mlir/test/Dialect/SPIRV/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -61,3 +61,43 @@
   %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
   return %0: f32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+  return %0: i32
+}
+
+// -----
+
+func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+  return
+}
+
+// -----
+
+func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+  return
+}