diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3135,6 +3135,7 @@
 def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
 def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
 def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
@@ -3263,14 +3264,15 @@
       SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
       SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
       SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
-      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
-      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
-      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU,
-      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
-      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
-      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
-      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
-      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
+      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS,
+      SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, SPV_OC_OpUConvert,
+      SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, SPV_OC_OpFNegate,
+      SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
+      SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
+      SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
       SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
       SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
       SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -173,6 +173,63 @@
 
 // -----
 
+def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
+  let summary = [{
+    Copy from the memory pointed to by Source to the memory pointed to by
+    Target. Both operands must be non-void pointers and having the same <id>
+    Type operand in their OpTypePointer type declaration. Matching Storage
+    Class is not required. The amount of memory copied is the size of the
+    type pointed to. The copied type must have a fixed size; i.e., it cannot
+    be, nor include, any OpTypeRuntimeArray types.
+  }];
+
+  let description = [{
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory operand
+    None. Before version 1.4, at most one memory operands mask can be
+    provided. Starting with version 1.4 two masks can be provided, as
+    described in Memory Operands. If no masks or only one mask is present,
+    it applies to both Source and Target. If two masks are present, the
+    first applies to Target and cannot include MakePointerVisible, and the
+    second applies to Source and cannot include MakePointerAvailable.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use
+                       `,` storage-class ssa-use
+                       (`[` memory-access `]`)?
+                       ` : ` spirv-element-type
+    ```
+
+    Both pointers must point to the same type; their storage classes may
+    differ. Only a single memory access mask, applying to both Source and
+    Target, is currently supported.
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    spv.CopyMemory "Function" %0, "Function" %1 : f32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$target,
+    SPV_AnyPtr:$source,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs);
+
+  // No `let verifier` override: the default `::verify(*this)` hook from
+  // SPV_Op dispatches to the C++ verifier in SPIRVOps.cpp.
+}
+
+// -----
+
 def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
   let summary = "Declare an execution mode for an entry point.";
 
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2812,6 +2812,91 @@
                              "have the same size");
     }
   }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CopyMemory
+//===----------------------------------------------------------------------===//
+
+// Parses the custom form:
+//   spv.CopyMemory "<target-sc>" %target, "<source-sc>" %source
+//       (`[` memory-access `]`)? attr-dict? : <pointee-type>
+// A single pointee type is spelled, so both operands are resolved as pointers
+// to that type in their respective storage classes.
+static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  spirv::StorageClass targetStorageClass;
+  OpAsmParser::OperandType targetPtrInfo;
+
+  spirv::StorageClass sourceStorageClass;
+  OpAsmParser::OperandType sourcePtrInfo;
+
+  Type elementType;
+
+  if (parseEnumStrAttr(targetStorageClass, parser) ||
+      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
+      parseEnumStrAttr(sourceStorageClass, parser) ||
+      parser.parseOperand(sourcePtrInfo) ||
+      parseMemoryAccessAttributes(parser, state) ||
+      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+
+  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
+  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
+
+  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
+      parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
+    return failure();
+  }
+
+  return success();
+}
+
+// Prints the custom form accepted by parseCopyMemoryOp.
+static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
+  auto targetType = copyMemory.target().getType().cast<spirv::PointerType>();
+  auto sourceType = copyMemory.source().getType().cast<spirv::PointerType>();
+
+  printer << spirv::CopyMemoryOp::getOperationName() << " \""
+          << stringifyStorageClass(targetType.getStorageClass()) << "\" "
+          << copyMemory.target() << ", \""
+          << stringifyStorageClass(sourceType.getStorageClass()) << "\" "
+          << copyMemory.source();
+
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
+  // Print the remaining attributes so they survive a round trip; the parser
+  // accepts them through parseOptionalAttrDict.
+  printer.printOptionalAttrDict(copyMemory.getAttrs(), elidedAttrs);
+
+  // The verifier guarantees matching pointee types, so the target's pointee
+  // type describes both operands.
+  printer << " : " << targetType.getPointeeType();
+}
+
+// Checks the pointee-type and memory-access constraints not expressed in ODS.
+static LogicalResult verify(spirv::CopyMemoryOp copyMemory) {
+  Type targetType =
+      copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
+  Type sourceType =
+      copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
+
+  // The SPIR-V spec requires both pointers to have the same pointee type. The
+  // custom assembly form depends on this too, since it spells one type only.
+  if (targetType != sourceType) {
+    return copyMemory.emitOpError(
+        "both operands must be pointers to the same type");
+  }
+
+  if (failed(verifyMemoryAccessAttribute(copyMemory)))
+    return failure();
+
+  // TODO: SPIR-V 1.4 allows two memory operand masks (the first for Target,
+  // the second for Source); only a single mask is modeled by this op.
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -57,3 +57,42 @@
     spv.Return
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_simple() "None" {
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} : f32
+    spv.CopyMemory "Function" %0, "Function" %1 : f32
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_different_storage_classes(%in : !spv.ptr<!spv.array<4xf32>, Input>, %out : !spv.ptr<!spv.array<4xf32>, Output>) "None" {
+    // CHECK: spv.CopyMemory "Output" %{{.*}}, "Input" %{{.*}} : !spv.array<4 x f32>
+    spv.CopyMemory "Output" %out, "Input" %in : !spv.array<4xf32>
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_with_access_operands() "None" {
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4] : f32
+
+    %2 = spv.Variable : !spv.ptr<f32, Function>
+    %3 = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+    spv.CopyMemory "Function" %2, "Function" %3 ["Volatile"] : f32
+    spv.Return
+  }
+}