Index: mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3252,6 +3252,8 @@
 def SPV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
 def SPV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
 def SPV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
+def SPV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
+def SPV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
 
 def SPV_OpcodeAttr :
     SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -3308,7 +3310,8 @@
       SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR,
       SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV,
       SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV,
-      SPV_OC_OpCooperativeMatrixLengthNV
+      SPV_OC_OpCooperativeMatrixLengthNV, SPV_OC_OpSubgroupBlockReadINTEL,
+      SPV_OC_OpSubgroupBlockWriteINTEL
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
@@ -88,7 +88,6 @@
   let assemblyFormat = [{
     $execution_scope operands attr-dict `:` type($value) `,` type($localid)
   }];
-
 }
 
 // -----
@@ -147,4 +146,122 @@
 
 // -----
 
+def SPV_SubgroupBlockReadINTELOp : SPV_Op<"SubgroupBlockReadINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Reads one or more components of Result data for each invocation in the
+    subgroup from the specified Ptr as a block operation.
+
+    The data is read strided, so the first value read is:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value read is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    Result Type may be a scalar or vector type, and its component type must be
+    equal to the type pointed to by Ptr.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-read-INTEL-op ::= ssa-id `=` `spv.SubgroupBlockReadINTEL`
+        storage-class ssa-use `:` spirv-element-type
+    ```
+
+    Here `spirv-element-type` is the result type: the pointee type of the
+    pointer operand, or a vector of it.
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    %1 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<2xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr
+  );
+
+  let results = (outs
+    SPV_Type:$value
+  );
+
+  // The builder derives the (scalar) result type from the pointee type of
+  // `basePtr`. `memory_access` and `alignment` are accepted but ignored: this
+  // op defines no corresponding attributes.
+  let builders = [
+    OpBuilder<[{
+      OpBuilder &builder, OperationState &state,
+      Value basePtr, IntegerAttr memory_access = {},
+      IntegerAttr alignment = {}
+    }]>
+  ];
+}
+
+// -----
+
+def SPV_SubgroupBlockWriteINTELOp : SPV_Op<"SubgroupBlockWriteINTEL", []> {
+  let summary = "See extension SPV_INTEL_subgroups";
+
+  let description = [{
+    Writes one or more components of Data for each invocation in the subgroup
+    from the specified Ptr as a block operation.
+
+    The data is written strided, so the first value is written to:
+    Ptr[ SubgroupLocalInvocationId ]
+
+    and the second value written is:
+    Ptr[ SubgroupLocalInvocationId + SubgroupMaxSize ]
+    etc.
+
+    The type of Ptr must be a pointer type, and must point to a scalar type.
+
+    The component type of Data must be equal to the type pointed to by Ptr.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    subgroup-block-write-INTEL-op ::= `spv.SubgroupBlockWriteINTEL`
+        storage-class ssa-use `,` ssa-use `:` spirv-element-type
+    ```
+
+    The first operand is the pointer and the second is the data to write;
+    `spirv-element-type` is the type of the data.
+
+    #### Example:
+
+    ```mlir
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[SPV_INTEL_subgroups]>,
+    Capability<[SPV_C_SubgroupBufferBlockIOINTEL]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyPtr:$ptr,
+    SPV_Type:$value
+  );
+
+  let results = (outs);
+}
+
+// -----
+
 #endif // SPIRV_GROUP_OPS
Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2025,6 +2025,127 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL / spv.SubgroupBlockWriteINTEL helpers
+//===----------------------------------------------------------------------===//
+
+/// Returns the pointer type that a block read/write of a value of type
+/// `valueType` operates on: the pointee is `valueType` itself for scalars and
+/// its element type for vectors.
+static spirv::PointerType
+getBlockReadWritePtrType(Type valueType, spirv::StorageClass storageClass) {
+  if (auto vectorType = valueType.dyn_cast<VectorType>())
+    valueType = vectorType.getElementType();
+  return spirv::PointerType::get(valueType, storageClass);
+}
+
+/// Verifies that the pointee type of `ptr` equals the type of `val` when
+/// `val` is a scalar, or the element type of `val` when it is a vector.
+template <typename BlockReadWriteOpTy>
+static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
+                                                        Value ptr, Value val) {
+  Type valType = val.getType();
+  if (auto valVecTy = valType.dyn_cast<VectorType>())
+    valType = valVecTy.getElementType();
+
+  if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType())
+    return op.emitOpError("mismatch in value type and pointer type");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+void spirv::SubgroupBlockReadINTELOp::build(OpBuilder &builder,
+                                            OperationState &state,
+                                            Value basePtr,
+                                            IntegerAttr memory_access,
+                                            IntegerAttr alignment) {
+  // The op has no memory-access or alignment attributes, so those parameters
+  // are ignored; the (scalar) result type is the pointee type of `basePtr`.
+  auto ptrType = basePtr.getType().cast<spirv::PointerType>();
+  build(builder, state, ptrType.getPointeeType(), basePtr);
+}
+
+static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
+                                                 OperationState &state) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  OpAsmParser::OperandType ptrInfo;
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
+      parser.parseColon() || parser.parseType(elementType)) {
+    return failure();
+  }
+
+  // A vector result reads from a pointer to its element type.
+  auto ptrType = getBlockReadWritePtrType(elementType, storageClass);
+  if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
+    return failure();
+  }
+
+  state.addTypes(elementType);
+  return success();
+}
+
+static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
+                  OpAsmPrinter &printer) {
+  // Print the storage class as a quoted string, exactly as
+  // parseSubgroupBlockReadINTELOp expects it, so the output can be re-parsed.
+  auto ptrType = blockReadOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockReadOp.ptr() << " : " << blockReadOp.getType();
+}
+
+static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
+                                            blockReadOp.value());
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
+                                                  OperationState &state) {
+  // Parse the storage class specification
+  spirv::StorageClass storageClass;
+  SmallVector<OpAsmParser::OperandType, 2> operandInfo;
+  auto loc = parser.getCurrentLocation();
+  Type elementType;
+  if (parseEnumStrAttr(storageClass, parser) ||
+      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+
+  // Vector data is written through a pointer to its element type.
+  auto ptrType = getBlockReadWritePtrType(elementType, storageClass);
+  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
+                             state.operands)) {
+    return failure();
+  }
+  return success();
+}
+
+static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
+                  OpAsmPrinter &printer) {
+  // Print the storage class as a quoted string, exactly as
+  // parseSubgroupBlockWriteINTELOp expects it, so the output can be re-parsed.
+  auto ptrType = blockWriteOp.ptr().getType().cast<spirv::PointerType>();
+  printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " \""
+          << spirv::stringifyStorageClass(ptrType.getStorageClass()) << "\" "
+          << blockWriteOp.ptr() << ", " << blockWriteOp.value() << " : "
+          << blockWriteOp.value().getType();
+}
+
+static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
+  return verifyBlockReadWritePtrAndValTypes(blockWriteOp, blockWriteOp.ptr(),
+                                            blockWriteOp.value());
+}
+
 //===----------------------------------------------------------------------===//
 // spv.GroupNonUniformElectOp
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
+++ mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
@@ -19,4 +19,28 @@
     %0 = spv.GroupBroadcast "Workgroup" %value, %localid : f32, vector<3xi32>
     spv.ReturnValue %0: f32
   }
+  // CHECK-LABEL: @subgroup_block_read_intel
+  spv.func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+    spv.ReturnValue %0: i32
+  }
+  // CHECK-LABEL: @subgroup_block_read_intel_vector
+  spv.func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> "None" {
+    // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+    %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+    spv.ReturnValue %0: vector<3xi32>
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel
+  spv.func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+    spv.Return
+  }
+  // CHECK-LABEL: @subgroup_block_write_intel_vector
+  spv.func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () "None" {
+    // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+    spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+    spv.Return
+  }
 }
Index: mlir/test/Dialect/SPIRV/group-ops.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/group-ops.mlir
+++ mlir/test/Dialect/SPIRV/group-ops.mlir
@@ -61,3 +61,43 @@
   %0 = spv.GroupBroadcast "Subgroup" %value, %localid : f32, vector<4xi32>
   return %0: f32
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockReadINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_read_intel(%ptr : !spv.ptr<i32, StorageBuffer>) -> i32 {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : i32
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : i32
+  return %0: i32
+}
+
+// -----
+
+func @subgroup_block_read_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>) -> vector<3xi32> {
+  // CHECK: spv.SubgroupBlockReadINTEL "StorageBuffer" %{{.*}} : vector<3xi32>
+  %0 = spv.SubgroupBlockReadINTEL "StorageBuffer" %ptr : vector<3xi32>
+  return %0: vector<3xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBlockWriteINTEL
+//===----------------------------------------------------------------------===//
+
+func @subgroup_block_write_intel(%ptr : !spv.ptr<i32, StorageBuffer>, %value: i32) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : i32
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : i32
+  return
+}
+
+// -----
+
+func @subgroup_block_write_intel_vector(%ptr : !spv.ptr<i32, StorageBuffer>, %value: vector<3xi32>) -> () {
+  // CHECK: spv.SubgroupBlockWriteINTEL "StorageBuffer" %{{.*}}, %{{.*}} : vector<3xi32>
+  spv.SubgroupBlockWriteINTEL "StorageBuffer" %ptr, %value : vector<3xi32>
+  return
+}