diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3135,6 +3135,7 @@
 def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
 def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
 def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpCopyMemory : I32EnumAttrCase<"OpCopyMemory", 63>;
 def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
 def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
 def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
@@ -3264,23 +3265,23 @@
       SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
       SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
       SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
-      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
-      SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct,
-      SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose,
-      SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF,
-      SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert,
-      SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd,
-      SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv,
-      SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod,
-      SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar,
-      SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr,
-      SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
-      SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
-      SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
-      SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
-      SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
+      SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpCopyMemory,
+      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate,
+      SPV_OC_OpCompositeConstruct, SPV_OC_OpCompositeExtract,
+      SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, SPV_OC_OpConvertFToU,
+      SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF,
+      SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast,
+      SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub,
+      SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv,
+      SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod,
+      SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual,
+      SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect,
+      SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
+      SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
+      SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
+      SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
       SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
       SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
       SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -173,6 +173,61 @@
 
 // -----
 
+def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> {
+  let summary = [{
+    Copy from the memory pointed to by Source to the memory pointed to by
+    Target. Both operands must be non-void pointers and having the same <id>
+    Type operand in their OpTypePointer type declaration. Matching Storage
+    Class is not required. The amount of memory copied is the size of the
+    type pointed to. The copied type must have a fixed size; i.e., it cannot
+    be, nor include, any OpTypeRuntimeArray types.
+  }];
+
+  let description = [{
+    If present, any Memory Operands must begin with a memory operand
+    literal. If not present, it is the same as specifying the memory operand
+    None. Before version 1.4, at most one memory operands mask can be
+    provided. Starting with version 1.4 two masks can be provided, as
+    described in Memory Operands. If no masks or only one mask is present,
+    it applies to both Source and Target. If two masks are present, the
+    first applies to Target and cannot include MakePointerVisible, and the
+    second applies to Source and cannot include MakePointerAvailable.
+
+    <!-- End of AutoGen section -->
+
+    Only a single memory operands mask is currently supported; it applies to
+    both Source and Target.
+
+    ```
+    copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use `, `
+                       storage-class ssa-use
+                       (`[` memory-access `]`)?
+                       ` : ` spirv-element-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    spv.CopyMemory "Function" %0, "Function" %1 : f32
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_AnyPtr:$target,
+    SPV_AnyPtr:$source,
+    OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
+    OptionalAttr<I32Attr>:$alignment
+  );
+
+  let results = (outs);
+
+  let verifier = [{ return verifyCopyMemory(*this); }];
+}
+
+// -----
+
 def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
   let summary = "Declare an execution mode for an entry point.";
 
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -183,17 +183,17 @@
   return parser.parseRSquare();
 }
 
-template <typename LoadStoreOpTy>
+template <typename MemoryOpTy>
 static void
-printMemoryAccessAttribute(LoadStoreOpTy loadStoreOp, OpAsmPrinter &printer,
+printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer,
                            SmallVectorImpl<StringRef> &elidedAttrs) {
   // Print optional memory access attribute.
-  if (auto memAccess = loadStoreOp.memory_access()) {
+  if (auto memAccess = memoryOp.memory_access()) {
     elidedAttrs.push_back(spirv::attributeName<spirv::MemoryAccess>());
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
 
     // Print integer alignment attribute.
-    if (auto alignment = loadStoreOp.alignment()) {
+    if (auto alignment = memoryOp.alignment()) {
       elidedAttrs.push_back(kAlignmentAttrName);
       printer << ", " << alignment;
     }
@@ -243,18 +243,18 @@
   return success();
 }
 
-template <typename LoadStoreOpTy>
-static LogicalResult verifyMemoryAccessAttribute(LoadStoreOpTy loadStoreOp) {
+template <typename MemoryOpTy>
+static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
   // ODS checks for attributes values. Just need to verify that if the
   // memory-access attribute is Aligned, then the alignment attribute must be
   // present.
-  auto *op = loadStoreOp.getOperation();
+  auto *op = memoryOp.getOperation();
   auto memAccessAttr = op->getAttr(spirv::attributeName<spirv::MemoryAccess>());
   if (!memAccessAttr) {
     // Alignment attribute shouldn't be present if memory access attribute is
     // not present.
     if (op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError(
+      return memoryOp.emitOpError(
           "invalid alignment specification without aligned memory access "
           "specification");
     }
@@ -265,17 +265,17 @@
   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
 
   if (!memAccess) {
-    return loadStoreOp.emitOpError("invalid memory access specifier: ")
+    return memoryOp.emitOpError("invalid memory access specifier: ")
            << memAccessVal;
   }
 
   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError("missing alignment value");
+      return memoryOp.emitOpError("missing alignment value");
     }
   } else {
     if (op->getAttr(kAlignmentAttrName)) {
-      return loadStoreOp.emitOpError(
+      return memoryOp.emitOpError(
           "invalid alignment specification with non-aligned memory access "
           "specification");
     }
@@ -2752,8 +2752,7 @@
 static LogicalResult
 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
   if (op.c().getType() != op.result().getType())
-    return op.emitOpError(
-        "result and third operand must have the same type");
+    return op.emitOpError("result and third operand must have the same type");
   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
   auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
   auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
@@ -2812,9 +2811,95 @@
                                "have the same size");
     }
   }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CopyMemory
+//===----------------------------------------------------------------------===//
+
+// Prints `spv.CopyMemory "<target-sc>" %target, "<source-sc>" %source`, then
+// the optional memory access list, the remaining attributes and
+// `: <pointee-type>`.
+static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
+  auto *op = copyMemory.getOperation();
+  printer << spirv::CopyMemoryOp::getOperationName() << ' ';
+
+  StringRef targetStorageClass =
+      stringifyStorageClass(copyMemory.target()
+                                .getType()
+                                .cast<spirv::PointerType>()
+                                .getStorageClass());
+  printer << "\"" << targetStorageClass << "\" " << copyMemory.target() << ", ";
+
+  StringRef sourceStorageClass =
+      stringifyStorageClass(copyMemory.source()
+                                .getType()
+                                .cast<spirv::PointerType>()
+                                .getStorageClass());
+  printer << "\"" << sourceStorageClass << "\" " << copyMemory.source();
+
+  SmallVector<StringRef, 4> elidedAttrs;
+  printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
+
+  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
+
+  Type pointeeType =
+      copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
+  printer << " : " << pointeeType;
+}
+
+// Parses the custom form of spv.CopyMemory. Both operand pointer types are
+// rebuilt from the single trailing element type and their storage classes.
+static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
+                                     OperationState &state) {
+  spirv::StorageClass targetStorageClass;
+  OpAsmParser::OperandType targetPtrInfo;
+
+  spirv::StorageClass sourceStorageClass;
+  OpAsmParser::OperandType sourcePtrInfo;
+
+  Type elementType;
+
+  if (parseEnumStrAttr(targetStorageClass, parser) ||
+      parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
+      parseEnumStrAttr(sourceStorageClass, parser) ||
+      parser.parseOperand(sourcePtrInfo) ||
+      parseMemoryAccessAttributes(parser, state) ||
+      parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
+      parser.parseType(elementType)) {
+    return failure();
+  }
+
+  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
+  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
+
+  if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
+      parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
+    return failure();
+  }
+
   return success();
 }
 
+// Verifies that both pointers share a pointee type and that the optional
+// memory access / alignment attributes are consistent with each other.
+static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
+  Type targetType =
+      copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
+
+  Type sourceType =
+      copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
+
+  if (targetType != sourceType) {
+    return copyMemory.emitOpError(
+        "both operands must be pointers to the same type");
+  }
+
+  return verifyMemoryAccessAttribute(copyMemory);
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Transpose
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir
@@ -57,3 +57,43 @@
     spv.Return
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_simple() "None" {
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} : f32
+    spv.CopyMemory "Function" %0, "Function" %1 : f32
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_different_storage_classes(%in : !spv.ptr<!spv.array<4xf32>, Input>, %out : !spv.ptr<!spv.array<4xf32>, Output>) "None" {
+    // CHECK: spv.CopyMemory "Output" %{{.*}}, "Input" %{{.*}} : !spv.array<4 x f32>
+    spv.CopyMemory "Output" %out, "Input" %in : !spv.array<4xf32>
+    spv.Return
+  }
+}
+
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @copy_memory_with_access_operands() "None" {
+    %0 = spv.Variable : !spv.ptr<f32, Function>
+    %1 = spv.Variable : !spv.ptr<f32, Function>
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4] : f32
+
+    // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+    spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32
+
+    spv.Return
+  }
+}
+
diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir
--- a/mlir/test/Dialect/SPIRV/ops.mlir
+++ b/mlir/test/Dialect/SPIRV/ops.mlir
@@ -1244,3 +1244,38 @@
   %0 = spv.Variable : !spv.ptr<f32, Generic>
   return
 }
+
+// -----
+
+func @copy_memory_incompatible_ptrs() -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<i32, Function>
+  // expected-error @+1 {{both operands must be pointers to the same type}}
+  "spv.CopyMemory"(%0, %1) {} : (!spv.ptr<f32, Function>, !spv.ptr<i32, Function>) -> ()
+  return
+}
+
+// -----
+
+func @copy_memory_invalid_maa() -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+  // expected-error @+1 {{missing alignment value}}
+  "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  return
+}
+
+// -----
+
+func @copy_memory_print_maa() -> () {
+  %0 = spv.Variable : !spv.ptr<f32, Function>
+  %1 = spv.Variable : !spv.ptr<f32, Function>
+
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
+  "spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
+  // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
+  "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+
+  return
+}