diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -198,7 +198,7 @@ ``` copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use storage-class ssa-use - (`[` memory-access `]`)? + (`[` memory-access `]` (`, [` memory-access `]`)?)? ` : ` spirv-element-type ``` @@ -215,12 +215,16 @@ SPV_AnyPtr:$target, SPV_AnyPtr:$source, OptionalAttr:$memory_access, - OptionalAttr:$alignment + OptionalAttr:$alignment, + OptionalAttr:$source_memory_access, + OptionalAttr:$source_alignment ); let results = (outs); let verifier = [{ return verifyCopyMemory(*this); }]; + + let autogenSerialization = 0; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -28,7 +28,11 @@ using namespace mlir; // TODO: generate these strings using ODS. +static constexpr const char kMemoryAccessAttrName[] = "memory_access"; +static constexpr const char kSourceMemoryAccessAttrName[] = + "source_memory_access"; static constexpr const char kAlignmentAttrName[] = "alignment"; +static constexpr const char kSourceAlignmentAttrName[] = "source_alignment"; static constexpr const char kBranchWeightAttrName[] = "branch_weights"; static constexpr const char kCallee[] = "callee"; static constexpr const char kClusterSize[] = "cluster_size"; @@ -157,6 +161,12 @@ return success(); } +/// Parses optional memory access attributes attached to a memory access +/// operand/pointer. Specifically, parses the following syntax: +/// (`[` memory-access `]`)? 
+/// where: +/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` +/// integer-literal | `"NonTemporal"` static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' @@ -166,7 +176,8 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state)) { + if (parseEnumStrAttr(memoryAccessAttr, parser, state, + kMemoryAccessAttrName)) { return failure(); } @@ -183,19 +194,90 @@ return parser.parseRSquare(); } +// TODO Make sure to merge this and the previous function into one template +// parameterized by memory access attribute name and alignment. Doing so now +// results in VS2017 producing an internal error (at the call site) that's +// not detailed enough to understand what is happening. +static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, + OperationState &state) { + // Parse an optional list of attributes starting with '[' + if (parser.parseOptionalLSquare()) { + // Nothing to do + return success(); + } + + spirv::MemoryAccess memoryAccessAttr; + if (parseEnumStrAttr(memoryAccessAttr, parser, state, + kSourceMemoryAccessAttrName)) { + return failure(); + } + + if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { + // Parse integer attribute for alignment. 
+ Attribute alignmentAttr; + Type i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseComma() || + parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, + state.attributes)) { + return failure(); + } + } + return parser.parseRSquare(); +} + template -static void -printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs) { +static void printMemoryAccessAttribute( + MemoryOpTy memoryOp, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs, + Optional memoryAccessAtrrValue = None, + Optional alignmentAttrValue = None) { // Print optional memory access attribute. - if (auto memAccess = memoryOp.memory_access()) { - elidedAttrs.push_back(spirv::attributeName()); + if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue + : memoryOp.memory_access())) { + elidedAttrs.push_back(kMemoryAccessAttrName); + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; - // Print integer alignment attribute. - if (auto alignment = memoryOp.alignment()) { - elidedAttrs.push_back(kAlignmentAttrName); - printer << ", " << alignment; + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + // Print integer alignment attribute. + if (auto alignment = (alignmentAttrValue ? alignmentAttrValue + : memoryOp.alignment())) { + elidedAttrs.push_back(kAlignmentAttrName); + printer << ", " << alignment; + } + } + printer << "]"; + } + elidedAttrs.push_back(spirv::attributeName()); +} + +// TODO Make sure to merge this and the previous function into one template +// parameterized by memory access attribute name and alignment. Doing so now +// results in VS2017 producing an internal error (at the call site) that's +// not detailed enough to understand what is happening. 
+template +static void printSourceMemoryAccessAttribute( + MemoryOpTy memoryOp, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs, + Optional memoryAccessAtrrValue = None, + Optional alignmentAttrValue = None) { + + printer << ", "; + + // Print optional memory access attribute. + if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue + : memoryOp.memory_access())) { + elidedAttrs.push_back(kSourceMemoryAccessAttrName); + + printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; + + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + // Print integer alignment attribute. + if (auto alignment = (alignmentAttrValue ? alignmentAttrValue + : memoryOp.alignment())) { + elidedAttrs.push_back(kSourceAlignmentAttrName); + printer << ", " << alignment; + } } printer << "]"; } @@ -249,7 +331,7 @@ // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); - auto memAccessAttr = op->getAttr(spirv::attributeName()); + auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. @@ -283,6 +365,50 @@ return success(); } +// TODO Make sure to merge this and the previous function into one template +// parameterized by memory access attribute name and alignment. Doing so now +// results in VS2017 producing an internal error (at the call site) that's +// not detailed enough to understand what is happening. +template +static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { + // ODS checks for attribute values. Just need to verify that if the + // memory-access attribute is Aligned, then the alignment attribute must be + // present. 
+ auto *op = memoryOp.getOperation(); + auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); + if (!memAccessAttr) { + // Alignment attribute shouldn't be present if memory access attribute is + // not present. + if (op->getAttr(kSourceAlignmentAttrName)) { + return memoryOp.emitOpError( + "invalid alignment specification without aligned memory access " + "specification"); + } + return success(); + } + + auto memAccessVal = memAccessAttr.template cast(); + auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); + + if (!memAccess) { + return memoryOp.emitOpError("invalid memory access specifier: ") + << memAccessVal; + } + + if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { + if (!op->getAttr(kSourceAlignmentAttrName)) { + return memoryOp.emitOpError("missing alignment value"); + } + } else { + if (op->getAttr(kSourceAlignmentAttrName)) { + return memoryOp.emitOpError( + "invalid alignment specification with non-aligned memory access " + "specification"); + } + } + return success(); +} + template static LogicalResult verifyMemorySemantics(BarrierOp op) { // According to the SPIR-V specification: @@ -2832,6 +2958,9 @@ SmallVector elidedAttrs; printMemoryAccessAttribute(copyMemory, printer, elidedAttrs); + printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs, + copyMemory.source_memory_access(), + copyMemory.source_alignment()); printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); @@ -2854,12 +2983,23 @@ parser.parseOperand(targetPtrInfo) || parser.parseComma() || parseEnumStrAttr(sourceStorageClass, parser) || parser.parseOperand(sourcePtrInfo) || - parseMemoryAccessAttributes(parser, state) || - parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() || - parser.parseType(elementType)) { + parseMemoryAccessAttributes(parser, state)) { return failure(); } + if (!parser.parseOptionalComma()) { + // Parse 2nd memory access attributes. 
+ if (parseSourceMemoryAccessAttributes(parser, state)) { + return failure(); + } + } + + if (parser.parseColon() || parser.parseType(elementType)) + return failure(); + + if (parser.parseOptionalAttrDict(state.attributes)) + return failure(); + auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); @@ -2883,7 +3023,19 @@ "both operands must be pointers to the same type"); } - return verifyMemoryAccessAttribute(copyMemory); + if (failed(verifyMemoryAccessAttribute(copyMemory))) { + return failure(); + } + + // TODO - According to the spec: + // + // If two masks are present, the first applies to Target and cannot include + // MakePointerVisible, and the second applies to Source and cannot include + // MakePointerAvailable. + // + // Add such verification here. + + return verifySourceMemoryAccessAttribute(copyMemory); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -2511,6 +2511,76 @@ return success(); } +template <> +LogicalResult +Deserializer::processOp(ArrayRef words) { + SmallVector resultTypes; + size_t wordIndex = 0; + SmallVector operands; + SmallVector attributes; + + if (wordIndex < words.size()) { + auto arg = getValue(words[wordIndex]); + + if (!arg) { + return emitError(unknownLoc, "unknown result : ") + << words[wordIndex]; + } + + operands.push_back(arg); + wordIndex++; + } + + if (wordIndex < words.size()) { + auto arg = getValue(words[wordIndex]); + + if (!arg) { + return emitError(unknownLoc, "unknown result : ") + << words[wordIndex]; + } + + operands.push_back(arg); + wordIndex++; + } + + bool isAlignedAttr = false; + + if (wordIndex < words.size()) { + auto attrValue = 
words[wordIndex++]; + attributes.push_back(opBuilder.getNamedAttr( + "memory_access", opBuilder.getI32IntegerAttr(attrValue))); + isAlignedAttr = (attrValue == 2); + } + + if (isAlignedAttr && wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "source_memory_access", + opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex < words.size()) { + attributes.push_back(opBuilder.getNamedAttr( + "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); + } + + if (wordIndex != words.size()) { + return emitError(unknownLoc, + "found more operands than expected when deserializing " + "spirv::CopyMemoryOp, only ") + << wordIndex << " of " << words.size() << " processed"; + } + + Location loc = createFileLineColLoc(opBuilder); + opBuilder.create(loc, resultTypes, operands, attributes); + + return success(); +} + // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -364,7 +364,8 @@ /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == /// 1 and autogenSerialization == 1 in ODS. 
- template LogicalResult processOp(OpTy op) { + template + LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); } @@ -1904,6 +1905,51 @@ operands); } +template <> +LogicalResult +Serializer::processOp(spirv::CopyMemoryOp op) { + SmallVector operands; + SmallVector elidedAttrs; + + for (Value operand : op.getOperation()->getOperands()) { + auto id = getValueID(operand); + assert(id && "use before def!"); + operands.push_back(id); + } + + if (auto attr = op.getAttr("memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("memory_access"); + + if (auto attr = op.getAttr("alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("alignment"); + + if (auto attr = op.getAttr("source_memory_access")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_memory_access"); + + if (auto attr = op.getAttr("source_alignment")) { + operands.push_back(static_cast( + attr.cast().getValue().getZExtValue())); + } + + elidedAttrs.push_back("source_alignment"); + emitDebugLine(functionBody, op.getLoc()); + encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); + + return success(); +} + // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. 
#define GET_SERIALIZATION_FNS diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -93,6 +93,18 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32 spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32 + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Volatile"] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4], ["Volatile"] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Aligned", 4] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Aligned", 4] : f32 + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 8], ["Aligned", 4] : f32 + spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 8], ["Aligned", 4] : f32 + spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -1247,7 +1247,7 @@ // ----- -func @copy_memory_incompatible_ptrs() -> () { +func @copy_memory_incompatible_ptrs() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{both operands must be pointers to the same type}} @@ -1257,7 +1257,7 @@ // ----- -func @copy_memory_invalid_maa() -> () { +func @copy_memory_invalid_maa() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{missing alignment value}} @@ -1267,7 +1267,27 @@ // ----- -func @copy_memory_print_maa() -> () { +func @copy_memory_invalid_source_maa() { + %0 = spv.Variable : !spv.ptr + %1 = spv.Variable : !spv.ptr + // 
expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}} + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return +} + +// ----- + +func @copy_memory_invalid_source_maa2() { + %0 = spv.Variable : !spv.ptr + %1 = spv.Variable : !spv.ptr + // expected-error @+1 {{missing alignment value}} + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return +} + +// ----- + +func @copy_memory_print_maa() { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr @@ -1277,5 +1297,11 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32 "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + + // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32 + "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () + spv.Return }