diff --git a/mlir/examples/standalone/include/Standalone/StandaloneDialect.h b/mlir/examples/standalone/include/Standalone/StandaloneDialect.h --- a/mlir/examples/standalone/include/Standalone/StandaloneDialect.h +++ b/mlir/examples/standalone/include/Standalone/StandaloneDialect.h @@ -9,6 +9,7 @@ #ifndef STANDALONE_STANDALONEDIALECT_H #define STANDALONE_STANDALONEDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "Standalone/StandaloneOpsDialect.h.inc" diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -74,6 +74,10 @@ /// Read a reference to the given attribute. virtual LogicalResult readAttribute(Attribute &result) = 0; + /// Read an optional reference to the given attribute. Returns success even if + /// the Attribute isn't present. + virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0; + template LogicalResult readAttributes(SmallVectorImpl &attrs) { return readList(attrs, [this](T &attr) { return readAttribute(attr); }); @@ -88,6 +92,18 @@ return emitError() << "expected " << llvm::getTypeName() << ", but got: " << baseResult; } + template + LogicalResult readOptionalAttribute(T &result) { + Attribute baseResult; + if (failed(readOptionalAttribute(baseResult))) + return failure(); + if (!baseResult) + return success(); + if ((result = dyn_cast(baseResult))) + return success(); + return emitError() << "expected " << llvm::getTypeName() + << ", but got: " << baseResult; + } /// Read a reference to the given type. virtual LogicalResult readType(Type &result) = 0; @@ -179,6 +195,7 @@ /// Write a reference to the given attribute. 
virtual void writeAttribute(Attribute attr) = 0; + virtual void writeOptionalAttribute(Attribute attr) = 0; template void writeAttributes(ArrayRef attrs) { writeList(attrs, [this](T attr) { writeAttribute(attr); }); diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.h b/mlir/include/mlir/Bytecode/BytecodeOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.h @@ -0,0 +1,26 @@ +//===- BytecodeOpInterface.h - Bytecode interface for MLIR Op ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the BytecodeOpInterface defined in +// `BytecodeOpInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEOPINTERFACE_H +#define MLIR_BYTECODE_BYTECODEOPINTERFACE_H + +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +/// Include the generated interface declarations. +#include "mlir/Bytecode/BytecodeOpInterface.h.inc" + +#endif // MLIR_BYTECODE_BYTECODEOPINTERFACE_H diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td @@ -0,0 +1,43 @@ +//===- BytecodeOpInterface.td - Bytecode OpInterface -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains an interface for operation interactions with the bytecode +// serialization/deserialization, in particular for properties. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEOPINTERFACES +#define MLIR_BYTECODE_BYTECODEOPINTERFACES + +include "mlir/IR/OpBase.td" + +// `BytecodeOpInterface` +def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> { + let description = [{ + This interface allows operation to control the serialization of their + properties. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod<[{ + Read the properties for this operation from the bytecode and populate the state. + }], + "LogicalResult", "readProperties", (ins + "::mlir::DialectBytecodeReader &":$reader, + "::mlir::OperationState &":$state) + >, + InterfaceMethod<[{ + Write the properties for this operation to the bytecode. + }], + "void", "writeProperties", (ins "::mlir::DialectBytecodeWriter &":$writer) + >, + ]; +} + +#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -46,6 +46,9 @@ /// is returned by bytecode writer entry point. void setDesiredBytecodeVersion(int64_t bytecodeVersion); + /// Get the set desired bytecode version to emit. 
+ int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bytecode/CMakeLists.txt b/mlir/include/mlir/Bytecode/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(BytecodeOpInterface) diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -29,7 +29,7 @@ kMinSupportedVersion = 0, /// The current bytecode version. - kVersion = 4, + kVersion = 5, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, @@ -69,8 +69,11 @@ /// This section contains the versions of each dialect. kDialectVersions = 7, + /// This section contains the properties for the operations. + kProperties = 8, + /// The total number of section types. 
- kNumSections = 8, + kNumSections = 9, }; } // namespace Section @@ -90,6 +93,7 @@ kHasSuccessors = 0b00001000, kHasInlineRegions = 0b00010000, kHasUseListOrders = 0b00100000, + kHasProperties = 0b01000000, // clang-format on }; } // namespace OpEncodingMask diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Bytecode) add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_ #define MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h --- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_ #define MLIR_DIALECT_AMX_AMXDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H #define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include 
"mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_ARITH_IR_ARITH_H_ #define MLIR_DIALECT_ARITH_IR_ARITH_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ #define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H #define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H #define MLIR_DIALECT_ASYNC_IR_ASYNC_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Async/IR/AsyncTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_ #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Interfaces/CopyOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h --- 
a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_EMITC_IR_EMITC_H #define MLIR_DIALECT_EMITC_IR_EMITC_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_FUNC_IR_OPS_H #define MLIR_DIALECT_FUNC_IR_OPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ 
b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H #define MLIR_DIALECT_GPU_IR_GPUDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h @@ -13,11 +13,13 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_ #define MLIR_DIALECT_IRDL_IR_IRDL_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" #include "mlir/Dialect/IRDL/IR/IRDLTraits.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" + #include // Forward declaration. diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_INDEX_IR_INDEXOPS_H #define MLIR_DIALECT_INDEX_IR_INDEXOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- 
a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -22,6 +22,7 @@ #ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ #define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_IR_LINALG_H #define MLIR_DIALECT_LINALG_IR_LINALG_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h @@ -8,6 +8,7 @@ #ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h" #include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h --- a/mlir/include/mlir/Dialect/Math/IR/Math.h 
+++ b/mlir/include/mlir/Dialect/Math/IR/Math.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MATH_IR_MATH_H_ #define MLIR_DIALECT_MATH_IR_MATH_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Dialect.h" @@ -21,6 +22,7 @@ #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" + #include namespace mlir { diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H #define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_ #define MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include 
"mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -18,6 +18,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc" #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc" #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_PDL_IR_PDLOPS_H_ #define MLIR_DIALECT_PDL_IR_PDLOPS_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ #define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H #define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 
#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_SHAPE_IR_SHAPE_H #define MLIR_DIALECT_SHAPE_IR_SHAPE_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_ #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include 
"mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_TOSA_IR_TOSAOPS_H #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h --- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H #define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -13,6 +13,7 @@ #ifndef 
MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ #define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -238,6 +238,25 @@ llvm::hash_value($_storage); }]; + // The call expression to emit the storage type to bytecode. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_writer` is a `DialectBytecodeWriter`. + code writeToMlirBytecode = [{ + writeToMlirBytecode($_writer, $_storage) + }]; + + // The call expression to read the storage type from bytecode. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_reader` is a `DialectBytecodeReader`. + code readFromMlirBytecode = [{ + if (failed(readFromMlirBytecode($_reader, $_storage))) + return failure(); + }]; + // Default value for the property. string defaultValue = ?; } @@ -1156,6 +1175,20 @@ //===----------------------------------------------------------------------===// // Primitive property kinds +// Any kind of integer stored as properties. 
+class IntProperty : + Property { + code writeToMlirBytecode = [{ + $_writer.writeVarInt($_storage); + }]; + code readFromMlirBytecode = [{ + uint64_t val; + if (failed($_reader.readVarInt(val))) + return ::mlir::failure(); + $_storage = val; + }]; +} + class ArrayProperty : Property { let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -58,6 +58,14 @@ // in the provided interface type and assign it to the storage. StringRef getConvertFromAttributeCall() const; + // Returns the method call which reads this property from + // bytecode and assigns it to the storage. + StringRef getReadFromMlirBytecodeCall() const; + + // Returns the method call which writes this property + // to the bytecode. + StringRef getWriteToMlirBytecodeCall() const; + // Returns the code to compute the hash for this property. StringRef getHashPropertyCall() const; diff --git a/mlir/lib/Bytecode/BytecodeOpInterface.cpp b/mlir/lib/Bytecode/BytecodeOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/BytecodeOpInterface.cpp @@ -0,0 +1,17 @@ +//===- BytecodeOpInterface.cpp - Bytecode Op Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// BytecodeOpInterface +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.cpp.inc" diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt --- a/mlir/lib/Bytecode/CMakeLists.txt +++ b/mlir/lib/Bytecode/CMakeLists.txt @@ -1,2 +1,13 @@ add_subdirectory(Reader) add_subdirectory(Writer) + +add_mlir_library(MLIRBytecodeOpInterface + BytecodeOpInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -11,6 +11,7 @@ #include "mlir/Bytecode/BytecodeReader.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" @@ -20,6 +21,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -28,6 +30,7 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include #include #include #include @@ -56,13 +59,15 @@ return "ResourceOffset (6)"; case bytecode::Section::kDialectVersions: return "DialectVersions (7)"; + case bytecode::Section::kProperties: + return "Properties (8)"; 
default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } } /// Returns true if the given top-level section ID is optional. -static bool isSectionOptional(bytecode::Section::ID sectionID) { +static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { switch (sectionID) { case bytecode::Section::kString: case bytecode::Section::kDialect: @@ -74,6 +79,8 @@ case bytecode::Section::kResourceOffset: case bytecode::Section::kDialectVersions: return true; + case bytecode::Section::kProperties: + return version < 5; default: llvm_unreachable("unknown section ID"); } @@ -362,6 +369,17 @@ return parseEntry(reader, strings, result, "string"); } + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section. + /// This variant parses a flag compressed with the index. + LogicalResult parseStringWithFlag(EncodingReader &reader, StringRef &result, + bool &flag) { + uint64_t entryIdx; + if (failed(reader.parseVarIntWithFlag(entryIdx, flag))) + return failure(); + return parseStringAtIndex(reader, entryIdx, result); + } + /// Parse a shared string from the string section. The shared string is /// encoded using an index to a corresponding string in the string section. LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, @@ -459,8 +477,9 @@ /// This struct represents an operation name entry within the bytecode. struct BytecodeOperationName { - BytecodeOperationName(BytecodeDialect *dialect, StringRef name) - : dialect(dialect), name(name) {} + BytecodeOperationName(BytecodeDialect *dialect, StringRef name, + std::optional wasRegistered) + : dialect(dialect), name(name), wasRegistered(wasRegistered) {} /// The loaded operation name, or std::nullopt if it hasn't been processed /// yet. @@ -471,6 +490,10 @@ /// The name of the operation, without the dialect prefix. 
StringRef name;
+
+  /// Whether this operation was registered when the bytecode was produced.
+  /// This flag is populated when bytecode version >=5.
+  std::optional wasRegistered;
 };
 } // namespace
 
@@ -791,6 +814,18 @@
     result = resolveAttribute(attrIdx);
     return success(!!result);
   }
+  LogicalResult parseOptionalAttribute(EncodingReader &reader,
+                                       Attribute &result) {
+    uint64_t attrIdx;
+    bool flag;
+    if (failed(reader.parseVarIntWithFlag(attrIdx, flag)))
+      return failure();
+    if (!flag)
+      return success();
+    result = resolveAttribute(attrIdx);
+    return success(!!result);
+  }
+
   LogicalResult parseType(EncodingReader &reader, Type &result) {
     uint64_t typeIdx;
     if (failed(reader.parseVarInt(typeIdx)))
@@ -870,7 +905,9 @@
   LogicalResult readAttribute(Attribute &result) override {
     return attrTypeReader.parseAttribute(reader, result);
   }
-
+  LogicalResult readOptionalAttribute(Attribute &result) override {
+    return attrTypeReader.parseOptionalAttribute(reader, result);
+  }
   LogicalResult readType(Type &result) override {
     return attrTypeReader.parseType(reader, result);
   }
@@ -957,6 +994,87 @@
   ResourceSectionReader &resourceReader;
   EncodingReader &reader;
 };
+
+/// Wraps the properties section and handles reading properties out of it.
+class PropertiesSectionReader {
+public:
+  /// Initialize the properties section reader with the given section data.
+  LogicalResult initialize(Location fileLoc, ArrayRef sectionData) {
+    if (sectionData.empty())
+      return success();
+    EncodingReader propReader(sectionData, fileLoc);
+    uint64_t count;
+    if (failed(propReader.parseVarInt(count)))
+      return failure();
+    // Keep the rest of the section as the raw properties buffer.
+    if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers)))
+      return failure();
+
+    EncodingReader offsetsReader(propertiesBuffers, fileLoc);
+    offsetTable.reserve(count);
+    for (auto idx : llvm::seq(0, count)) {
+      (void)idx;
+      offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size());
+      ArrayRef rawProperties;
+      uint64_t dataSize;
+      if (failed(offsetsReader.parseVarInt(dataSize)) ||
+          failed(offsetsReader.parseBytes(dataSize, rawProperties)))
+        return failure();
+    }
+    if (!offsetsReader.empty())
+      return offsetsReader.emitError()
+             << "Broken properties section: didn't exhaust the offsets table";
+    return success();
+  }
+
+  /// Read the properties selected by the next varint index into `opState`.
+  LogicalResult read(Location fileLoc, DialectReader &dialectReader,
+                     OperationName *opName, OperationState &opState) {
+    uint64_t propertiesIdx;
+    if (failed(dialectReader.readVarInt(propertiesIdx)))
+      return failure();
+    if (propertiesIdx >= offsetTable.size())
+      return dialectReader.emitError("Properties idx out-of-bound for ")
+             << opName->getStringRef();
+    size_t propertiesOffset = offsetTable[propertiesIdx];
+    if (propertiesOffset >= propertiesBuffers.size())
+      return dialectReader.emitError("Properties offset out-of-bound for ")
+             << opName->getStringRef();
+
+    // Acquire the sub-buffer that represents the requested properties.
+    ArrayRef rawProperties;
+    {
+      // "Seek" to the requested offset by getting a new reader with the right
+      // sub-buffer.
+      EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset),
+                            fileLoc);
+      // Properties are stored as a sequence of {size + raw_data}.
+      if (failed(
+              dialectReader.withEncodingReader(reader).readBlob(rawProperties)))
+        return failure();
+    }
+    // Setup a new reader to read from the `rawProperties` sub-buffer.
+    EncodingReader reader(
+        StringRef(rawProperties.begin(), rawProperties.size()), fileLoc);
+    DialectReader propReader = dialectReader.withEncodingReader(reader);
+
+    if (auto *iface = opName->getInterface())
+      return iface->readProperties(propReader, opState);
+    if (opName->isRegistered())
+      return propReader.emitError(
+                 "has properties but missing BytecodeOpInterface for ")
+             << opName->getStringRef();
+    // Unregistered ops store their properties as an attribute.
+    return propReader.readAttribute(opState.propertiesAttr);
+  }
+
+private:
+  /// The properties buffer referenced within the bytecode file.
+  ArrayRef propertiesBuffers;
+
+  /// Start offset of each properties entry within the buffer above.
+  SmallVector offsetTable;
+};
 } // namespace
 
 LogicalResult
@@ -1194,7 +1312,9 @@
       lazyLoadableOps.erase(it->getSecond());
       lazyLoadableOpsMap.erase(it);
       auto result = parseRegions(regionStack, regionStack.back());
-      assert(regionStack.empty());
+      assert((regionStack.empty() || failed(result)) &&
+             "broken invariant: regionStack should be empty when parseRegions "
+             "succeeds");
       return result;
     }
 
@@ -1209,8 +1329,11 @@
 
   LogicalResult parseDialectSection(ArrayRef sectionData);
 
-  /// Parse an operation name reference using the given reader.
-  FailureOr parseOpName(EncodingReader &reader);
+  /// Parse an operation name reference using the given reader, and set the
+  /// `wasRegistered` flag that indicates if the bytecode was produced by a
+  /// context where opName was registered.
+  FailureOr parseOpName(EncodingReader &reader,
+                        std::optional &wasRegistered);
 
   //===--------------------------------------------------------------------===//
   // Attribute/Type Section
@@ -1398,6 +1521,9 @@
   /// The table of strings referenced within the bytecode file.
   StringSectionReader stringReader;
 
+  /// The table of properties referenced by the operation in the bytecode file.
+  PropertiesSectionReader propertiesReader;
+
   /// The current set of available IR value scopes.
std::vector valueScopes; @@ -1466,7 +1592,7 @@ // Check that all of the required sections were found. for (int i = 0; i < bytecode::Section::kNumSections; ++i) { bytecode::Section::ID sectionID = static_cast(i); - if (!sectionDatas[i] && !isSectionOptional(sectionID)) { + if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { return reader.emitError("missing data for top-level section: ", ::toString(sectionID)); } @@ -1477,6 +1603,12 @@ fileLoc, *sectionDatas[bytecode::Section::kString]))) return failure(); + // Process the properties section. + if (sectionDatas[bytecode::Section::kProperties] && + failed(propertiesReader.initialize( + fileLoc, *sectionDatas[bytecode::Section::kProperties]))) + return failure(); + // Process the dialect section. if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); @@ -1598,9 +1730,20 @@ // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { StringRef opName; - if (failed(stringReader.parseString(sectionReader, opName))) - return failure(); - opNames.emplace_back(dialect, opName); + std::optional wasRegistered; + // Prior to version 5, the information about wheter an op was registered or + // not wasn't encoded. 
+    if (version < 5) {
+      if (failed(stringReader.parseString(sectionReader, opName)))
+        return failure();
+    } else {
+      bool wasRegisteredFlag;
+      if (failed(stringReader.parseStringWithFlag(sectionReader, opName,
+                                                  wasRegisteredFlag)))
+        return failure();
+      wasRegistered = wasRegisteredFlag;
+    }
+    opNames.emplace_back(dialect, opName, wasRegistered);
     return success();
   };
   // Avoid re-allocation in bytecode version > 3 where the number of ops are
@@ -1618,11 +1761,12 @@
 }
 
 FailureOr
-BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
+BytecodeReader::Impl::parseOpName(EncodingReader &reader,
+                                  std::optional &wasRegistered) {
   BytecodeOperationName *opName = nullptr;
   if (failed(parseEntry(reader, opNames, opName, "operation name")))
     return failure();
-
+  wasRegistered = opName->wasRegistered;
   // Check to see if this operation name has already been resolved. If we
   // haven't, load the dialect and build the operation name.
   if (!opName->opName) {
@@ -1994,7 +2138,8 @@
                                RegionReadState &readState,
                                bool &isIsolatedFromAbove) {
   // Parse the name of the operation.
-  FailureOr opName = parseOpName(reader);
+  std::optional wasRegistered;
+  FailureOr opName = parseOpName(reader, wasRegistered);
   if (failed(opName))
     return failure();
 
@@ -2021,6 +2166,31 @@
     opState.attributes = dictAttr;
   }
 
+  if (opMask & bytecode::OpEncodingMask::kHasProperties) {
+    // kHasProperties wasn't emitted in older bytecode, we should never get
+    // there without also having the `wasRegistered` flag available.
+    if (!wasRegistered)
+      return emitError(fileLoc,
+                       "Unexpected missing `wasRegistered` opname flag at "
+                       "bytecode version ")
+             << version << " with properties.";
+    // Ops emitted while unregistered store their properties as an attribute;
+    // registered ops serialize them through the bytecode interface.
+    // NOTE(review): this tests has_value(), which is always true here.
+    if (wasRegistered) {
+      DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+                                  reader);
+      if (failed(
+              propertiesReader.read(fileLoc, dialectReader, &*opName, opState)))
+        return failure();
+    } else {
+      // NOTE(review): unreachable as written because the condition above
+      // is always true. Confirm whether `*wasRegistered` was intended.
+      if (failed(parseAttribute(reader, opState.propertiesAttr)))
+        return failure();
+    }
+  }
+
   /// Parse the results of the operation.
   if (opMask & bytecode::OpEncodingMask::kHasResults) {
     uint64_t numResults;
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -9,11 +9,23 @@
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "IRNumbering.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Bytecode/Encoding.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/CachedHashString.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
+#include
+#include
+#include
+#include
+#include
 
 #define DEBUG_TYPE "mlir-bytecode-writer"
 
@@ -58,6 +70,10 @@
       std::min(bytecodeVersion, bytecode::kVersion);
 }
 
+int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const {
+  return impl->bytecodeVersion;
+}
+
 //===----------------------------------------------------------------------===//
 // EncodingEmitter
 //===----------------------------------------------------------------------===//
@@ -318,6 +334,14 @@
   void writeAttribute(Attribute attr) override {
     emitter.emitVarInt(numberingState.getNumber(attr));
   }
+  void writeOptionalAttribute(Attribute attr) override {
+    if (!attr) {
+      emitter.emitVarInt(0);
+      return;
+
}
+    emitter.emitVarIntWithFlag(numberingState.getNumber(attr), true);
+  }
+
   void writeType(Type type) override {
     emitter.emitVarInt(numberingState.getNumber(type));
   }
@@ -382,6 +406,105 @@
   StringSectionBuilder &stringSection;
 };
 
+namespace {
+/// Collects serialized op properties, de-duplicating identical payloads.
+class PropertiesSectionBuilder {
+public:
+  PropertiesSectionBuilder(IRNumberingState &numberingState,
+                           StringSectionBuilder &stringSection,
+                           const BytecodeWriterConfig::Impl &config)
+      : numberingState(numberingState), stringSection(stringSection),
+        config(config) {}
+
+  /// Emit the op properties in the properties section and return their index
+  /// within the section, or std::nullopt if no properties were emitted.
+  std::optional emit(Operation *op) {
+    if (!op->getPropertiesStorageSize())
+      return std::nullopt;
+    if (!op->isRegistered()) {
+      // Unregistered ops store their properties as an optional attribute.
+      Attribute prop = *op->getPropertiesStorage().as();
+      if (!prop)
+        return std::nullopt;
+      EncodingEmitter attrEmitter;
+      attrEmitter.emitVarInt(numberingState.getNumber(prop));
+      scratch.clear();
+      llvm::raw_svector_ostream os(scratch);
+      attrEmitter.writeTo(os);
+      return emit(scratch);
+    }
+
+    EncodingEmitter emitter;
+    DialectWriter propertiesWriter(config.bytecodeVersion, emitter,
+                                   numberingState, stringSection);
+    auto iface = cast(op);
+    iface.writeProperties(propertiesWriter);
+    scratch.clear();
+    llvm::raw_svector_ostream os(scratch);
+    emitter.writeTo(os);
+    return emit(scratch);
+  }
+
+  /// Write the current set of properties to the given emitter.
+  void write(EncodingEmitter &emitter) {
+    emitter.emitVarInt(propertiesStorage.size());
+    if (propertiesStorage.empty())
+      return;
+    for (const auto &storage : propertiesStorage) {
+      if (storage.empty()) {
+        emitter.emitBytes(ArrayRef());
+        continue;
+      }
+      emitter.emitBytes(ArrayRef(reinterpret_cast(&storage[0]),
+                                 storage.size()));
+    }
+  }
+
+  /// Returns true if the section is empty.
+  bool empty() { return propertiesStorage.empty(); }
+
+private:
+  /// Emit raw data and return its index in the properties table.
+  /// Data are deduplicated: the bytes are kept only if an identical entry
+  /// does not exist already.
+  ssize_t emit(ArrayRef rawProperties) {
+    // Populate a scratch buffer with the properties size.
+    SmallVector sizeScratch;
+    {
+      EncodingEmitter sizeEmitter;
+      sizeEmitter.emitVarInt(rawProperties.size());
+      llvm::raw_svector_ostream os(sizeScratch);
+      sizeEmitter.writeTo(os);
+    }
+    // Append a new storage to the table now.
+    size_t index = propertiesStorage.size();
+    propertiesStorage.emplace_back();
+    std::vector &newStorage = propertiesStorage.back();
+    size_t propertiesSize = sizeScratch.size() + rawProperties.size();
+    newStorage.reserve(propertiesSize);
+    newStorage.insert(newStorage.end(), sizeScratch.begin(), sizeScratch.end());
+    newStorage.insert(newStorage.end(), rawProperties.begin(),
+                      rawProperties.end());
+
+    // Try to de-duplicate the new serialized properties.
+    // If the properties is a duplicate, pop it back from the storage.
+    auto inserted = propertiesUniquing.insert(
+        std::make_pair(ArrayRef(newStorage), index));
+    if (!inserted.second)
+      propertiesStorage.pop_back();
+    return inserted.first->getSecond();
+  }
+
+  /// Storage for properties.
+  std::vector> propertiesStorage;
+  SmallVector scratch;
+  DenseMap, int64_t> propertiesUniquing;
+  IRNumberingState &numberingState;
+  StringSectionBuilder &stringSection;
+  const BytecodeWriterConfig::Impl &config;
+};
+} // namespace
+
 /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
 /// to go through an intermediate buffer when interacting with code that wants a
 /// raw_ostream.
@@ -435,11 +558,12 @@ namespace { class BytecodeWriter { public: - BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config) - : numberingState(op), config(config) {} + BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) + : numberingState(op, config), config(config.getImpl()), + propertiesSection(numberingState, stringSection, config.getImpl()) {} /// Write the bytecode for the given root operation. - void write(Operation *rootOp, raw_ostream &os); + LogicalResult write(Operation *rootOp, raw_ostream &os); private: //===--------------------------------------------------------------------===// @@ -455,10 +579,10 @@ //===--------------------------------------------------------------------===// // Operations - void writeBlock(EncodingEmitter &emitter, Block *block); - void writeOp(EncodingEmitter &emitter, Operation *op); - void writeRegion(EncodingEmitter &emitter, Region *region); - void writeIRSection(EncodingEmitter &emitter, Operation *op); + LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); + LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); + LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); + LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); //===--------------------------------------------------------------------===// // Resources @@ -470,6 +594,11 @@ void writeStringSection(EncodingEmitter &emitter); + //===--------------------------------------------------------------------===// + // Properties + + void writePropertiesSection(EncodingEmitter &emitter); + //===--------------------------------------------------------------------===// // Helpers @@ -487,10 +616,13 @@ /// Configuration dictating bytecode emission. 
const BytecodeWriterConfig::Impl &config; + + /// Storage for the properties section + PropertiesSectionBuilder propertiesSection; }; } // namespace -void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { +LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { EncodingEmitter emitter; // Emit the bytecode file header. This is how we identify the output as a @@ -510,7 +642,8 @@ writeAttrTypeSection(emitter); // Emit the IR section. - writeIRSection(emitter, rootOp); + if (failed(writeIRSection(emitter, rootOp))) + return failure(); // Emit the resources section. writeResourceSection(rootOp, emitter); @@ -518,8 +651,17 @@ // Emit the string section. writeStringSection(emitter); + // Emit the properties section. + if (config.bytecodeVersion >= 5) + writePropertiesSection(emitter); + else if (!propertiesSection.empty()) + return rootOp->emitError( + "unexpected properties emitted incompatible with bytecode <5"); + // Write the generated bytecode to the provided output stream. emitter.writeTo(os); + + return success(); } //===----------------------------------------------------------------------===// @@ -590,7 +732,11 @@ // Emit the referenced operation names grouped by dialect. 
auto emitOpName = [&](OpNameNumbering &name) { - dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect())); + size_t stringId = stringSection.insert(name.name.stripDialect()); + if (config.bytecodeVersion < 5) + dialectEmitter.emitVarInt(stringId); + else + dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered()); }; writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName); @@ -659,7 +805,8 @@ //===----------------------------------------------------------------------===// // Operations -void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) { +LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, + Block *block) { ArrayRef args = block->getArguments(); bool hasArgs = !args.empty(); @@ -696,10 +843,12 @@ // Emit the operations within the block. for (Operation &op : *block) - writeOp(emitter, &op); + if (failed(writeOp(emitter, &op))) + return failure(); + return success(); } -void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { emitter.emitVarInt(numberingState.getNumber(op->getName())); // Emit a mask for the operation components. We need to fill this in later @@ -713,10 +862,24 @@ emitter.emitVarInt(numberingState.getNumber(op->getLoc())); // Emit the attributes of this operation. - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); + // Allow deployment to version <5 by merging inherent attribute with the + // discardable ones. We should fail if there are any conflicts. 
+ if (config.bytecodeVersion < 5) + attrs = op->getAttrDictionary(); if (!attrs.empty()) { opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; - emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary())); + emitter.emitVarInt(numberingState.getNumber(attrs)); + } + + // Emit the properties of this operation, for now we still support deployment + // to version <5. + if (config.bytecodeVersion >= 5) { + std::optional propertiesId = propertiesSection.emit(op); + if (propertiesId.has_value()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; + emitter.emitVarInt(*propertiesId); + } } // Emit the result types of the operation. @@ -768,15 +931,18 @@ // If the region is not isolated from above, or we are emitting bytecode // targeting version <2, we don't use a section. if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { - writeRegion(emitter, ®ion); + if (failed(writeRegion(emitter, ®ion))) + return failure(); continue; } EncodingEmitter regionEmitter; - writeRegion(regionEmitter, ®ion); + if (failed(writeRegion(regionEmitter, ®ion))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); } } + return success(); } void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, @@ -867,11 +1033,14 @@ } } -void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { +LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, + Region *region) { // If the region is empty, we only need to emit the number of blocks (which is // zero). - if (region->empty()) - return emitter.emitVarInt(/*numBlocks*/ 0); + if (region->empty()) { + emitter.emitVarInt(/*numBlocks*/ 0); + return success(); + } // Emit the number of blocks and values within the region. unsigned numBlocks, numValues; @@ -881,10 +1050,13 @@ // Emit the blocks within the region. 
for (Block &block : *region) - writeBlock(emitter, &block); + if (failed(writeBlock(emitter, &block))) + return failure(); + return success(); } -void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, + Operation *op) { EncodingEmitter irEmitter; // Write the IR section the same way as a block with no arguments. Note that @@ -893,9 +1065,11 @@ irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false); // Emit the operations. - writeOp(irEmitter, op); + if (failed(writeOp(irEmitter, op))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); + return success(); } //===----------------------------------------------------------------------===// @@ -1011,14 +1185,22 @@ emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter)); } +//===----------------------------------------------------------------------===// +// Properties + +void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { + EncodingEmitter propertiesEmitter; + propertiesSection.write(propertiesEmitter); + emitter.emitSection(bytecode::Section::kProperties, + std::move(propertiesEmitter)); +} + //===----------------------------------------------------------------------===// // Entry Points //===----------------------------------------------------------------------===// LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config) { - BytecodeWriter writer(op, config.getImpl()); - writer.write(op, os); - // Currently there is no failure case. 
- return success(); + BytecodeWriter writer(op, config); + return writer.write(op, os); } diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt --- a/mlir/lib/Bytecode/Writer/CMakeLists.txt +++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt @@ -8,4 +8,5 @@ LINK_LIBS PUBLIC MLIRIR MLIRSupport + MLIRBytecodeOpInterface ) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -18,6 +18,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/CodeGen/NonRelocatableStringpool.h" +#include namespace mlir { class BytecodeDialectInterface; @@ -133,7 +135,7 @@ /// emission. class IRNumberingState { public: - IRNumberingState(Operation *op); + IRNumberingState(Operation *op, const BytecodeWriterConfig &config); /// Return the numbered dialects. auto getDialects() { @@ -241,6 +243,9 @@ /// The next value ID to assign when numbering. unsigned nextValueID = 0; + + // Configuration: useful to query the required version to emit. 
+ const BytecodeWriterConfig &config; }; } // namespace detail } // namespace bytecode diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -8,6 +8,7 @@ #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -24,6 +25,10 @@ NumberingDialectWriter(IRNumberingState &state) : state(state) {} void writeAttribute(Attribute attr) override { state.number(attr); } + void writeOptionalAttribute(Attribute attr) override { + if (attr) + state.number(attr); + } void writeType(Type type) override { state.number(type); } void writeResourceHandle(const AsmDialectResourceHandle &resource) override { state.number(resource.getDialect(), resource); @@ -106,7 +111,9 @@ value->number = idx; } -IRNumberingState::IRNumberingState(Operation *op) { +IRNumberingState::IRNumberingState(Operation *op, + const BytecodeWriterConfig &config) + : config(config) { // Compute a global operation ID numbering according to the pre-order walk of // the IR. This is used as reference to construct use-list orders. unsigned operationID = 0; @@ -276,10 +283,30 @@ } // Only number the operation's dictionary if it isn't empty. - DictionaryAttr dictAttr = op.getAttrDictionary(); + DictionaryAttr dictAttr = op.getDiscardableAttrDictionary(); + // Prior to version 5 we need to number also the merged dictionnary + // containing both the inherent and discardable attribute. + if (config.getDesiredBytecodeVersion() < 5) + dictAttr = op.getAttrDictionary(); if (!dictAttr.empty()) number(dictAttr); + // Visit the operation properties (if any) to make sure referenced attributes + // are numbered. 
+ if (config.getDesiredBytecodeVersion() >= 5 && + op.getPropertiesStorageSize()) { + if (op.isRegistered()) { + // Operation that have properties *must* implement this interface. + auto iface = cast(op); + NumberingDialectWriter writer(*this); + iface.writeProperties(writer); + } else { + // Unregistered op are storing properties as an optional attribute. + if (Attribute prop = *op.getPropertiesStorage().as()) + number(prop); + } + } + number(op.getLoc()); } diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -72,6 +72,16 @@ return getValueAsString(init); } +StringRef Property::getReadFromMlirBytecodeCall() const { + const auto *init = def->getValueInit("readFromMlirBytecode"); + return getValueAsString(init); +} + +StringRef Property::getWriteToMlirBytecodeCall() const { + const auto *init = def->getValueInit("writeToMlirBytecode"); + return getValueAsString(init); +} + StringRef Property::getHashPropertyCall() const { return getValueAsString(def->getValueInit("hashProperty")); } diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version +// VERSION: bytecode version 127 is newer than the current version {{[0-9]+}} //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -28,12 +28,14 @@ #include 
"mlir/IR/Verifier.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Reducer/ReductionPatternInterface.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include #include #include @@ -62,6 +64,44 @@ return hash_value(StringRef(content)); } +static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, + MyPropStruct &prop) { + StringRef str; + if (failed(reader.readString(str))) + return failure(); + prop.content = str.str(); + return success(); +} + +static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer, + MyPropStruct &prop) { + writer.writeOwnedString(prop.content); +} + +static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, + MutableArrayRef prop) { + uint64_t size; + if (failed(reader.readVarInt(size))) + return failure(); + if (size != prop.size()) + return reader.emitError("array size mismach when reading properties: ") + << size << " vs expected " << prop.size(); + for (auto &elt : prop) { + uint64_t value; + if (failed(reader.readVarInt(value))) + return failure(); + elt = value; + } + return success(); +} + +static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer, + ArrayRef prop) { + writer.writeVarInt(prop.size()); + for (auto elt : prop) + writer.writeVarInt(elt); +} + static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr, InFlightDiagnostic *diagnostic); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3324,7 +3324,7 @@ def TestOpWithProperties : TEST_Op<"with_properties"> { let assemblyFormat = "prop-dict attr-dict"; let arguments = (ins - Property<"int64_t">:$a, + IntProperty<"int64_t">:$a, StrAttr:$b, // Attributes can 
directly be used here. ArrayProperty<"int64_t", 4>:$array // example of an array ); @@ -3369,8 +3369,31 @@ const Properties &prop); static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); + static ::mlir::LogicalResult readFromMlirBytecode( + ::mlir::DialectBytecodeReader &, + test::PropertiesWithCustomPrint &prop); + static void writeToMlirBytecode( + ::mlir::DialectBytecodeWriter &, + const test::PropertiesWithCustomPrint &prop); }]; let extraClassDefinition = [{ + ::mlir::LogicalResult TestOpWithNiceProperties::readFromMlirBytecode( + ::mlir::DialectBytecodeReader &reader, + test::PropertiesWithCustomPrint &prop) { + StringRef label; + uint64_t value; + if (failed(reader.readString(label)) || failed(reader.readVarInt(value))) + return failure(); + prop.label = std::make_shared(label.str()); + prop.value = value; + return success(); + } + void TestOpWithNiceProperties::writeToMlirBytecode( + ::mlir::DialectBytecodeWriter &writer, + const test::PropertiesWithCustomPrint &prop) { + writer.writeOwnedString(*prop.label); + writer.writeVarInt(prop.value); + } void TestOpWithNiceProperties::printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p, const Properties &prop) { customPrintProperties(p, prop.prop); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -14,6 +14,7 @@ #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H #define MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- 
a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1103,6 +1103,21 @@ "getDiag")) ->body(); + auto &readPropertiesMethod = + opClass + .addStaticMethod( + "::mlir::LogicalResult", "readProperties", + MethodParameter("::mlir::DialectBytecodeReader &", "reader"), + MethodParameter("::mlir::OperationState &", "state")) + ->body(); + + auto &writePropertiesMethod = + opClass + .addMethod( + "void", "writeProperties", + MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) + ->body(); + opClass.declare("Properties", "FoldAdaptor::Properties"); // Convert the property to the attribute form. @@ -1304,6 +1319,66 @@ } } verifyInherentAttrsMethod << " return ::mlir::success();"; + + // Populate bytecode serialization logic. + readPropertiesMethod + << " auto &prop = state.getOrAddProperties(); (void)prop;"; + writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n"; + for (const auto &attrOrProp : attrOrProperties) { + if (const auto *namedProperty = + attrOrProp.dyn_cast()) { + StringRef name = namedProperty->name; + FmtContext fctx; + fctx.addSubst("_reader", "reader") + .addSubst("_writer", "writer") + .addSubst("_storage", propertyStorage); + readPropertiesMethod << formatv( + R"( + {{ + auto &propStorage = prop.{0}; + auto readProp = [&]() { + {1}; + return ::mlir::success(); + }; + if (failed(readProp())) + return ::mlir::failure(); + } +)", + name, + tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx)); + writePropertiesMethod << formatv( + R"( + {{ + auto &propStorage = prop.{0}; + {1}; + } +)", + name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx)); + continue; + } + const auto *namedAttr = attrOrProp.dyn_cast(); + StringRef name = namedAttr->attrName; + if (namedAttr->isRequired) { + readPropertiesMethod << formatv(R"( + if (failed(reader.readAttribute(prop.{0}))) + return failure(); +)", + name); + writePropertiesMethod + << formatv(" 
writer.writeAttribute(prop.{0});\n", name); + } else { + readPropertiesMethod << formatv(R"( + if (failed(reader.readOptionalAttribute(prop.{0}))) + return failure(); +)", + name); + writePropertiesMethod << formatv(R"( + writer.writeOptionalAttribute(prop.{0}); +)", + name); + } + } + readPropertiesMethod << " return success();"; } void OpEmitter::genAttrGetters() { @@ -3300,6 +3375,9 @@ // native/interface traits and after all the traits with `StructuralOpTrait`. opClass.addTrait("::mlir::OpTrait::OpInvariants"); + if (emitHelper.hasProperties()) + opClass.addTrait("::mlir::BytecodeOpInterface::Trait"); + // Add the native and interface traits. for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) {