diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -198,7 +198,7 @@ ``` copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use storage-class ssa-use - (`[` memory-access `]`)? + (`[` memory-access `]` (`, [` memory-access `]`)?)? ` : ` spirv-element-type ``` @@ -215,12 +215,16 @@ SPV_AnyPtr:$target, SPV_AnyPtr:$source, OptionalAttr:$memory_access, - OptionalAttr:$alignment + OptionalAttr:$alignment, + OptionalAttr:$source_memory_access, + OptionalAttr:$source_alignment ); let results = (outs); let verifier = [{ return verifyCopyMemory(*this); }]; + + let autogenSerialization = 0; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -28,7 +28,11 @@ using namespace mlir; // TODO: generate these strings using ODS. +static constexpr const char kMemoryAccessAttrName[] = "memory_access"; +static constexpr const char kSourceMemoryAccessAttrName[] = + "source_memory_access"; static constexpr const char kAlignmentAttrName[] = "alignment"; +static constexpr const char kSourceAlignmentAttrName[] = "source_alignment"; static constexpr const char kBranchWeightAttrName[] = "branch_weights"; static constexpr const char kCallee[] = "callee"; static constexpr const char kClusterSize[] = "cluster_size"; @@ -157,6 +161,14 @@ return success(); } +/// Parses optional memory access attributes attached to a memory access +/// operand/pointer. Specifically, parses the following syntax: +/// (`[` memory-access `]`)? +/// where: +/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` +/// integer-literal | `"NonTemporal"` +template static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' @@ -166,7 +178,7 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state)) { + if (parseEnumStrAttr(memoryAccessAttr, parser, state, memoryAccessAttrName)) { return failure(); } @@ -175,7 +187,7 @@ Attribute alignmentAttr; Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseComma() || - parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, + parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, state.attributes)) { return failure(); } @@ -183,19 +195,32 @@ return parser.parseRSquare(); } -template -static void -printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs) { +template +static void printMemoryAccessAttribute( + MemoryOpTy memoryOp, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs, + Optional memoryAccessAtrrValue = None, + Optional alignmentAttrValue = None) { // Print optional memory access attribute. - if (auto memAccess = memoryOp.memory_access()) { - elidedAttrs.push_back(spirv::attributeName()); + if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue + : memoryOp.memory_access())) { + elidedAttrs.push_back(memoryAccessAttrName); + + if (!first) + printer << ", "; + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; - // Print integer alignment attribute. - if (auto alignment = memoryOp.alignment()) { - elidedAttrs.push_back(kAlignmentAttrName); - printer << ", " << alignment; + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + // Print integer alignment attribute. + if (auto alignment = (alignmentAttrValue ? alignmentAttrValue + : memoryOp.alignment())) { + elidedAttrs.push_back(alignmentAttrName); + printer << ", " << alignment; + } } printer << "]"; } @@ -243,17 +268,19 @@ return success(); } -template +template static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); - auto memAccessAttr = op->getAttr(spirv::attributeName()); + auto memAccessAttr = op->getAttr(memoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. - if (op->getAttr(kAlignmentAttrName)) { + if (op->getAttr(alignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); @@ -270,11 +297,11 @@ } if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { - if (!op->getAttr(kAlignmentAttrName)) { + if (!op->getAttr(alignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { - if (op->getAttr(kAlignmentAttrName)) { + if (op->getAttr(alignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); @@ -2832,6 +2859,10 @@ SmallVector elidedAttrs; printMemoryAccessAttribute(copyMemory, printer, elidedAttrs); + printMemoryAccessAttribute( + copyMemory, printer, elidedAttrs, copyMemory.source_memory_access(), + copyMemory.source_alignment()); printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); @@ -2854,12 +2885,24 @@ parser.parseOperand(targetPtrInfo) || parser.parseComma() || parseEnumStrAttr(sourceStorageClass, parser) || parser.parseOperand(sourcePtrInfo) || - parseMemoryAccessAttributes(parser, state) || - parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() || - parser.parseType(elementType)) { + parseMemoryAccessAttributes(parser, state)) { return failure(); } + if (!parser.parseOptionalComma()) { + // Parse 2nd memory access attributes. + if (parseMemoryAccessAttributes(parser, state)) { + return failure(); + } + } + + if (parser.parseColon() || parser.parseType(elementType)) + return failure(); + + if (parser.parseOptionalAttrDict(state.attributes)) + return failure(); + auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); @@ -2883,7 +2926,21 @@ "both operands must be pointers to the same type"); } - return verifyMemoryAccessAttribute(copyMemory); + if (failed(verifyMemoryAccessAttribute(copyMemory))) { + return failure(); + } + + // TODO - According to the spec: + // + // If two masks are present, the first applies to Target and cannot include + // MakePointerVisible, and the second applies to Source and cannot include + // MakePointerAvailable. + // + // Add such verification here. + + return verifyMemoryAccessAttribute(copyMemory); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -2511,6 +2511,76 @@ return success(); } +template <> +LogicalResult +Deserializer::processOp(ArrayRef words) { + SmallVector resultTypes; + size_t wordIndex = 0; + SmallVector operands; + SmallVector attributes; + + if (wordIndex < words.size()) { + auto arg = getValue(words[wordIndex]); + + if (!arg) { + return emitError(unknownLoc, "unknown result : ") + << words[wordIndex]; + } + + operands.push_back(arg); + wordIndex++; + } + + if (wordIndex < words.size()) { + auto arg = getValue(words[wordIndex]); + + if (!arg) { + return emitError(unknownLoc, "unknown result : ") + << words[wordIndex]; + } + + operands.push_back(arg); + wordIndex++; + } + + bool isAlignedAttr = false; + + if (wordIndex < words.size()) { + auto attrValue = words[wordIndex++]; + attributes.push_back(opBuilder.getNamedAttr( + "memory_access", opBuilder.getI32IntegerAttr(attrValue))); + isAlignedAttr = (attrValue == 2); + } + + if (isAlignedAttr && wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "source_memory_access", + opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex != words.size()) { + return emitError(unknownLoc, + "found more operands than expected when deserializing " + "spirv::CopyMemoryOp, only ") + << wordIndex << " of " << words.size() << " processed"; + } + + Location loc = createFileLineColLoc(opBuilder); + opBuilder.create(loc, resultTypes, operands, attributes); + + return success(); +} + // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -364,7 +364,8 @@ /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == /// 1 and autogenSerialization == 1 in ODS. - template LogicalResult processOp(OpTy op) { + template + LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); } @@ -1904,6 +1905,51 @@ operands); } +template <> +LogicalResult +Serializer::processOp(spirv::CopyMemoryOp op) { + SmallVector operands; + SmallVector elidedAttrs; + + for (Value operand : op.getOperation()->getOperands()) { + auto id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + if (auto attr = op.getAttr("memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("memory_access"); + + if (auto attr = op.getAttr("alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("alignment"); + + if (auto attr = op.getAttr("source_memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_memory_access"); + + if (auto attr = op.getAttr("source_alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_alignment"); + emitDebugLine(functionBody, op.getLoc()); + encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); + + return success(); +} + // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. #define GET_SERIALIZATION_FNS diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -93,6 +93,18 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32 spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32 + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Volatile"] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4], ["Volatile"] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Aligned", 4] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Aligned", 4] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 8], ["Aligned", 4] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 8], ["Aligned", 4] : f32 + spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -1247,7 +1247,7 @@ // ----- -func @copy_memory_incompatible_ptrs() -> () { +func @copy_memory_incompatible_ptrs() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{both operands must be pointers to the same type}} @@ -1257,7 +1257,7 @@ // ----- -func @copy_memory_invalid_maa() -> () { +func @copy_memory_invalid_maa() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{missing alignment value}} @@ -1267,7 +1267,27 @@ // ----- -func @copy_memory_print_maa() -> () { +func @copy_memory_invalid_source_maa() { + %0 = spv.Variable : !spv.ptr + %1 = spv.Variable : !spv.ptr + // expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}} + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return +} + +// ----- + +func @copy_memory_invalid_source_maa2() { + %0 = spv.Variable : !spv.ptr + %1 = spv.Variable : !spv.ptr + // expected-error @+1 {{missing alignment value}} + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return +} + +// ----- + +func @copy_memory_print_maa() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr @@ -1277,5 +1297,11 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32 "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32 + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return }