diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -198,7 +198,7 @@ ``` copy-memory-op ::= `spv.CopyMemory ` storage-class ssa-use storage-class ssa-use - (`[` memory-access `]` (`, [` memory-access `]`)?)? + (`[` memory-access `]`)? ` : ` spirv-element-type ``` @@ -215,16 +215,12 @@ SPV_AnyPtr:$target, SPV_AnyPtr:$source, OptionalAttr:$memory_access, - OptionalAttr:$alignment, - OptionalAttr:$source_memory_access, - OptionalAttr:$source_alignment + OptionalAttr:$alignment ); let results = (outs); let verifier = [{ return verifyCopyMemory(*this); }]; - - let autogenSerialization = 0; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -28,11 +28,7 @@ using namespace mlir; // TODO(antiagainst): generate these strings using ODS. -static constexpr const char kMemoryAccessAttrName[] = "memory_access"; -static constexpr const char kSourceMemoryAccessAttrName[] = - "source_memory_access"; static constexpr const char kAlignmentAttrName[] = "alignment"; -static constexpr const char kSourceAlignmentAttrName[] = "source_alignment"; static constexpr const char kBranchWeightAttrName[] = "branch_weights"; static constexpr const char kCallee[] = "callee"; static constexpr const char kClusterSize[] = "cluster_size"; @@ -161,8 +157,6 @@ return success(); } -template static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' @@ -172,7 +166,7 @@ } spirv::MemoryAccess memoryAccessAttr; - if (parseEnumStrAttr(memoryAccessAttr, parser, state, memoryAccessAttrName)) { + if (parseEnumStrAttr(memoryAccessAttr, parser, state)) { return failure(); } @@ -181,7 +175,7 @@ Attribute alignmentAttr; Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseComma() || - parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName, + parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, state.attributes)) { return failure(); } @@ -189,33 +183,19 @@ return parser.parseRSquare(); } -template -static void printMemoryAccessAttribute( - MemoryOpTy memoryOp, OpAsmPrinter &printer, - SmallVectorImpl &elidedAttrs, - Optional memoryAccessAtrrValue = None, - Optional alignmentAttrValue = None) { +template +static void +printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, + SmallVectorImpl &elidedAttrs) { // Print optional memory access attribute. - if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue - : memoryOp.memory_access())) { - elidedAttrs.push_back(memoryAccessAttrName); - - if (!first) { - printer << ", "; - } - + if (auto memAccess = memoryOp.memory_access()) { + elidedAttrs.push_back(spirv::attributeName()); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; - if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { - // Print integer alignment attribute. - if (auto alignment = (alignmentAttrValue ? alignmentAttrValue - : memoryOp.alignment())) { - elidedAttrs.push_back(alignmentAttrName); - printer << ", " << alignment; - } + // Print integer alignment attribute. + if (auto alignment = memoryOp.alignment()) { + elidedAttrs.push_back(kAlignmentAttrName); + printer << ", " << alignment; } printer << "]"; } @@ -263,19 +243,17 @@ return success(); } -template +template static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); - auto memAccessAttr = op->getAttr(memoryAccessAttrName); + auto memAccessAttr = op->getAttr(spirv::attributeName()); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. - if (op->getAttr(alignmentAttrName)) { + if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); @@ -292,11 +270,11 @@ } if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { - if (!op->getAttr(alignmentAttrName)) { + if (!op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { - if (op->getAttr(alignmentAttrName)) { + if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); @@ -2861,10 +2839,6 @@ SmallVector elidedAttrs; printMemoryAccessAttribute(copyMemory, printer, elidedAttrs); - printMemoryAccessAttribute( - copyMemory, printer, elidedAttrs, copyMemory.source_memory_access(), - copyMemory.source_alignment()); printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); @@ -2887,23 +2861,9 @@ parser.parseOperand(targetPtrInfo) || parser.parseComma() || parseEnumStrAttr(sourceStorageClass, parser) || parser.parseOperand(sourcePtrInfo) || - parseMemoryAccessAttributes(parser, state)) { - return failure(); - } - - if (!parser.parseOptionalComma()) { - // Parse 2nd memory access attributes. - if (parseMemoryAccessAttributes(parser, state)) { - return failure(); - } - } - - if (parser.parseColon() || parser.parseType(elementType)) { - return failure(); - } - - if (parser.parseOptionalAttrDict(state.attributes)) { + parseMemoryAccessAttributes(parser, state) || + parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() || + parser.parseType(elementType)) { return failure(); } @@ -2930,21 +2890,7 @@ "both operands must be pointers to the same type"); } - if (failed(verifyMemoryAccessAttribute(copyMemory))) { - return failure(); - } - - // TODO (ergawy): According to the spec: - // - // If two masks are present, the first applies to Target and cannot include - // MakePointerVisible, and the second applies to Source and cannot include - // MakePointerAvailable. - // - // Add such verification here. - - return verifyMemoryAccessAttribute(copyMemory); + return verifyMemoryAccessAttribute(copyMemory); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -400,8 +400,7 @@ /// Method to deserialize an operation in the SPIR-V dialect that is a mirror /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode /// == 1 and autogenSerialization == 1 in ODS. - template - LogicalResult processOp(ArrayRef words) { + template LogicalResult processOp(ArrayRef words) { return emitError(unknownLoc, "unsupported deserialization for ") << OpTy::getOperationName() << " op"; } @@ -1567,8 +1566,8 @@ return success(); } - return emitError(unknownLoc, "unsupported OpConstantNull type: ") - << resultType; + return emitError(unknownLoc, "unsupported OpConstantNull type: ") + << resultType; } //===----------------------------------------------------------------------===// @@ -2510,76 +2509,6 @@ return success(); } -template <> -LogicalResult -Deserializer::processOp(ArrayRef words) { - SmallVector resultTypes; - size_t wordIndex = 0; - SmallVector operands; - SmallVector attributes; - - if (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); - - if (!arg) { - return emitError(unknownLoc, "unknown result : ") - << words[wordIndex]; - } - - operands.push_back(arg); - wordIndex++; - } - - if (wordIndex < words.size()) { - auto arg = getValue(words[wordIndex]); - - if (!arg) { - return emitError(unknownLoc, "unknown result : ") - << words[wordIndex]; - } - - operands.push_back(arg); - wordIndex++; - } - - bool isAlignedAttr = false; - - if (wordIndex < words.size()) { - auto attrValue = words[wordIndex++]; - attributes.push_back(opBuilder.getNamedAttr( - "memory_access", opBuilder.getI32IntegerAttr(attrValue))); - isAlignedAttr = (attrValue == 2); - } - - if (isAlignedAttr && wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "source_memory_access", - opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex < words.size()) { - attributes.push_back(opBuilder.getNamedAttr( - "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); - } - - if (wordIndex != words.size()) { - return emitError(unknownLoc, - "found more operands than expected when deserializing " - "spirv::CopyMemoryOp, only ") - << wordIndex << " of " << words.size() << " processed"; - } - - Location loc = createFileLineColLoc(opBuilder); - opBuilder.create(loc, resultTypes, operands, attributes); - - return success(); -} - // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -364,8 +364,7 @@ /// Method to serialize an operation in the SPIR-V dialect that is a mirror of /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == /// 1 and autogenSerialization == 1 in ODS. - template - LogicalResult processOp(OpTy op) { + template LogicalResult processOp(OpTy op) { return op.emitError("unsupported op serialization"); } @@ -1905,51 +1904,6 @@ operands); } -template <> -LogicalResult -Serializer::processOp(spirv::CopyMemoryOp op) { - SmallVector operands; - SmallVector elidedAttrs; - - for (Value operand : op.getOperation()->getOperands()) { - auto id = getValueID(operand); - assert(id && "use before def!"); - operands.push_back(id); - } - - if (auto attr = op.getAttr("memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("memory_access"); - - if (auto attr = op.getAttr("alignment")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("alignment"); - - if (auto attr = op.getAttr("source_memory_access")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("source_memory_access"); - - if (auto attr = op.getAttr("source_alignment")) { - operands.push_back(static_cast( - attr.cast().getValue().getZExtValue())); - } - - elidedAttrs.push_back("source_alignment"); - emitDebugLine(functionBody, op.getLoc()); - encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); - - return success(); -} - // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. #define GET_SERIALIZATION_FNS diff --git a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/memory-ops.mlir @@ -93,18 +93,6 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32 spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"] : f32 - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Volatile"] : f32 - spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Volatile"] : f32 - - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 - spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 4], ["Volatile"] : f32 - - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"], ["Aligned", 4] : f32 - spv.CopyMemory "Function" %0, "Function" %1 ["Volatile"], ["Aligned", 4] : f32 - - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 8], ["Aligned", 4] : f32 - spv.CopyMemory "Function" %0, "Function" %1 ["Aligned", 8], ["Aligned", 4] : f32 - spv.Return } } diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -1247,7 +1247,7 @@ // ----- -func @copy_memory_incompatible_ptrs() { +func @copy_memory_incompatible_ptrs() -> () { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{both operands must be pointers to the same type}} @@ -1257,7 +1257,7 @@ // ----- -func @copy_memory_invalid_maa() { +func @copy_memory_invalid_maa() -> () { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr // expected-error @+1 {{missing alignment value}} @@ -1267,27 +1267,7 @@ // ----- -func @copy_memory_invalid_source_maa() { - %0 = spv.Variable : !spv.ptr - %1 = spv.Variable : !spv.ptr - // expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}} - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () - spv.Return -} - -// ----- - -func @copy_memory_invalid_source_maa2() { - %0 = spv.Variable : !spv.ptr - %1 = spv.Variable : !spv.ptr - // expected-error @+1 {{missing alignment value}} - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () - spv.Return -} - -// ----- - -func @copy_memory_print_maa() { +func @copy_memory_print_maa() -> () { %0 = spv.Variable : !spv.ptr %1 = spv.Variable : !spv.ptr @@ -1297,11 +1277,5 @@ // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32 "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32 - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () - - // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32 - "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr, !spv.ptr) -> () - spv.Return }