diff --git a/mlir/examples/standalone/include/Standalone/StandaloneDialect.h b/mlir/examples/standalone/include/Standalone/StandaloneDialect.h --- a/mlir/examples/standalone/include/Standalone/StandaloneDialect.h +++ b/mlir/examples/standalone/include/Standalone/StandaloneDialect.h @@ -9,6 +9,7 @@ #ifndef STANDALONE_STANDALONEDIALECT_H #define STANDALONE_STANDALONEDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "Standalone/StandaloneOpsDialect.h.inc" diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -74,6 +74,10 @@ /// Read a reference to the given attribute. virtual LogicalResult readAttribute(Attribute &result) = 0; + /// Read an optional reference to the given attribute. Returns success even if + /// the Attribute isn't present. + virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0; + template <typename T> LogicalResult readAttributes(SmallVectorImpl<T> &attrs) { return readList(attrs, [this](T &attr) { return readAttribute(attr); }); @@ -88,6 +92,18 @@ return emitError() << "expected " << llvm::getTypeName<T>() << ", but got: " << baseResult; } + template <typename T> + LogicalResult readOptionalAttribute(T &result) { + Attribute baseResult; + if (failed(readOptionalAttribute(baseResult))) + return failure(); + if (!baseResult) + return success(); + if ((result = dyn_cast<T>(baseResult))) + return success(); + return emitError() << "expected " << llvm::getTypeName<T>() + << ", but got: " << baseResult; + } /// Read a reference to the given type. virtual LogicalResult readType(Type &result) = 0; @@ -179,6 +195,7 @@ /// Write a reference to the given attribute.
virtual void writeAttribute(Attribute attr) = 0; + virtual void writeOptionalAttribute(Attribute attr) = 0; template <typename T> void writeAttributes(ArrayRef<T> attrs) { writeList(attrs, [this](T attr) { writeAttribute(attr); }); diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.h b/mlir/include/mlir/Bytecode/BytecodeOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.h @@ -0,0 +1,26 @@ +//===- BytecodeOpInterface.h - Bytecode OpInterface -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the BytecodeOpInterface defined in +// `BytecodeOpInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEOPINTERFACES_H +#define MLIR_BYTECODE_BYTECODEOPINTERFACES_H + +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LogicalResult.h" + +/// Include the generated interface declarations. +#include "mlir/Bytecode/BytecodeOpInterface.h.inc" + +#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES_H diff --git a/mlir/include/mlir/Bytecode/BytecodeOpInterface.td b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeOpInterface.td @@ -0,0 +1,42 @@ +//===- BytecodeOpInterface.td - Bytecode OpInterface -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains an interface for operation interactions with the bytecode +// serialization/deserialization, in particular for properties. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEOPINTERFACES +#define MLIR_BYTECODE_BYTECODEOPINTERFACES + +include "mlir/IR/OpBase.td" + +// `BytecodeOpInterface` +def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> { + let description = [{ + This interface allows operations to control the serialization of their + properties. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + StaticInterfaceMethod<[{ + Read this operation's properties from `reader` and populate `state`. + }], + "LogicalResult", "readProperties", (ins "::mlir::DialectBytecodeReader &":$reader, + "::mlir::OperationState &":$state) + >, + InterfaceMethod<[{ + Write this operation's properties to `writer`. + }], + "void", "writeProperties", (ins "::mlir::DialectBytecodeWriter &":$writer) + >, + ]; +} + +#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -46,6 +46,9 @@ /// is returned by bytecode writer entry point. void setDesiredBytecodeVersion(int64_t bytecodeVersion); + /// Get the set desired bytecode version to emit.
+ int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bytecode/CMakeLists.txt b/mlir/include/mlir/Bytecode/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(BytecodeOpInterface) diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -29,7 +29,7 @@ kMinSupportedVersion = 0, /// The current bytecode version. - kVersion = 3, + kVersion = 4, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, @@ -69,8 +69,11 @@ /// This section contains the versions of each dialect. kDialectVersions = 7, + /// This section contains the properties for the operations. + kProperties = 8, + /// The total number of section types. 
- kNumSections = 8, + kNumSections = 9, }; } // namespace Section @@ -90,6 +93,7 @@ kHasSuccessors = 0b00001000, kHasInlineRegions = 0b00010000, kHasUseListOrders = 0b00100000, + kHasProperties = 0b01000000, // clang-format on }; } // namespace OpEncodingMask diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Bytecode) add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_ #define MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h --- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_ #define MLIR_DIALECT_AMX_AMXDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h --- a/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h +++ b/mlir/include/mlir/Dialect/Affine/TransformOps/AffineTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H #define MLIR_DIALECT_AFFINE_TRANSFORMOPS_AFFINETRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include 
"mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_ARITH_IR_ARITH_H_ #define MLIR_DIALECT_ARITH_IR_ARITH_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h --- a/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h +++ b/mlir/include/mlir/Dialect/ArmNeon/ArmNeonDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ #define MLIR_DIALECT_ARMNEON_ARMNEONDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H #define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H #define MLIR_DIALECT_ASYNC_IR_ASYNC_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Async/IR/AsyncTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_ #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Interfaces/CopyOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H #define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" diff --git a/mlir/include/mlir/Dialect/Complex/IR/Complex.h b/mlir/include/mlir/Dialect/Complex/IR/Complex.h --- a/mlir/include/mlir/Dialect/Complex/IR/Complex.h +++ b/mlir/include/mlir/Dialect/Complex/IR/Complex.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ #define MLIR_DIALECT_COMPLEX_IR_COMPLEX_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h --- 
a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlow.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOW_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.h.inc" diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h --- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h +++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H #define MLIR_DIALECT_CONTROLFLOW_IR_CONTROLFLOWOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_EMITC_IR_EMITC_H #define MLIR_DIALECT_EMITC_IR_EMITC_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.h +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_FUNC_IR_OPS_H #define MLIR_DIALECT_FUNC_IR_OPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ 
b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_GPU_IR_GPUDIALECT_H #define MLIR_DIALECT_GPU_IR_GPUDIALECT_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h --- a/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h +++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDL.h @@ -13,11 +13,13 @@ #ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_ #define MLIR_DIALECT_IRDL_IR_IRDL_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h" #include "mlir/Dialect/IRDL/IR/IRDLTraits.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" + #include // Forward declaration. diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h --- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h +++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_INDEX_IR_INDEXOPS_H #define MLIR_DIALECT_INDEX_IR_INDEXOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Index/IR/IndexAttrs.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h --- 
a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -22,6 +22,7 @@ #ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ #define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_IR_LINALG_H #define MLIR_DIALECT_LINALG_IR_LINALG_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h @@ -8,6 +8,7 @@ #ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h" #include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h --- a/mlir/include/mlir/Dialect/Math/IR/Math.h 
+++ b/mlir/include/mlir/Dialect/Math/IR/Math.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MATH_IR_MATH_H_ #define MLIR_DIALECT_MATH_IR_MATH_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/IR/Dialect.h" @@ -21,6 +22,7 @@ #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" + #include namespace mlir { diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H #define MLIR_DIALECT_MEMREF_TRANSFORMOPS_MEMREFTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_ #define MLIR_DIALECT_NVGPU_NVGPUDIALECT_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include 
"mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h --- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h @@ -18,6 +18,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.h.inc" #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.h.inc" #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_PDL_IR_PDLOPS_H_ #define MLIR_DIALECT_PDL_IR_PDLOPS_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ #define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H #define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" 
#include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/OpImplementation.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_SHAPE_IR_SHAPE_H #define MLIR_DIALECT_SHAPE_IR_SHAPE_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/FunctionInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_ #define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSOR_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h --- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h +++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_ #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_ +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include 
"mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_TOSA_IR_TOSAOPS_H #define MLIR_DIALECT_TOSA_IR_TOSAOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h --- a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ #define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ +#include 
"mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/lib/Bytecode/BytecodeOpInterface.cpp b/mlir/lib/Bytecode/BytecodeOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/BytecodeOpInterface.cpp @@ -0,0 +1,17 @@ +//===- CallInterfaces.cpp - ControlFlow Interfaces ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// BytecodeOpInterface +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeOpInterface.cpp.inc" diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt --- a/mlir/lib/Bytecode/CMakeLists.txt +++ b/mlir/lib/Bytecode/CMakeLists.txt @@ -1,2 +1,13 @@ add_subdirectory(Reader) add_subdirectory(Writer) + +add_mlir_library(MLIRBytecodeOpInterface + BytecodeOpInterface.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -11,14 +11,17 @@ #include "mlir/Bytecode/BytecodeReader.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include 
"mlir/IR/OpImplementation.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -26,6 +29,7 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include #include #include #include @@ -54,13 +58,15 @@ return "ResourceOffset (6)"; case bytecode::Section::kDialectVersions: return "DialectVersions (7)"; + case bytecode::Section::kProperties: + return "Properties (8)"; default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } } /// Returns true if the given top-level section ID is optional. -static bool isSectionOptional(bytecode::Section::ID sectionID) { +static bool isSectionOptional(bytecode::Section::ID sectionID, int version) { switch (sectionID) { case bytecode::Section::kString: case bytecode::Section::kDialect: @@ -72,6 +78,10 @@ case bytecode::Section::kResourceOffset: case bytecode::Section::kDialectVersions: return true; + case bytecode::Section::kProperties: + if (version < 4) + return true; + return false; default: llvm_unreachable("unknown section ID"); } @@ -777,6 +787,18 @@ result = resolveAttribute(attrIdx); return success(!!result); } + LogicalResult parseOptionalAttribute(EncodingReader &reader, + Attribute &result) { + uint64_t attrIdx; + bool flag; + if (failed(reader.parseVarIntWithFlag(attrIdx, flag))) + return failure(); + if (!flag) + return success(); + result = resolveAttribute(attrIdx); + return success(!!result); + } + LogicalResult parseType(EncodingReader &reader, Type &result) { uint64_t typeIdx; if (failed(reader.parseVarInt(typeIdx))) @@ -856,7 +878,9 @@ LogicalResult readAttribute(Attribute &result) override { return attrTypeReader.parseAttribute(reader, result); } - + LogicalResult readOptionalAttribute(Attribute &result) override { + 
return attrTypeReader.parseOptionalAttribute(reader, result); + } LogicalResult readType(Type &result) override { return attrTypeReader.parseType(reader, result); } @@ -943,6 +967,92 @@ ResourceSectionReader &resourceReader; EncodingReader &reader; }; + +/// Wraps the properties section and handles reading properties out of it. +class PropertiesSectionReader { +public: + /// Initialize the properties section reader with the given section data. + LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData) { + if (sectionData.empty()) + return success(); + EncodingReader propReader(sectionData, fileLoc); + uint64_t count; + if (failed(propReader.parseVarInt(count))) + return failure(); + // Parse the raw properties buffer. + if (failed(propReader.parseBytes(propReader.size(), propertiesBuffers))) + return failure(); + + EncodingReader offsetsReader(propertiesBuffers, fileLoc); + offsetTable.reserve(count); + for (auto idx : llvm::seq<int64_t>(0, count)) { + (void)idx; + offsetTable.push_back(propertiesBuffers.size() - offsetsReader.size()); + ArrayRef<uint8_t> rawProperties; + uint64_t dataSize; + if (failed(offsetsReader.parseVarInt(dataSize)) || + failed(offsetsReader.parseBytes(dataSize, rawProperties))) + return failure(); + } + if (!offsetsReader.empty()) + return offsetsReader.emitError() + << "Broken properties section: didn't exhaust the offsets table"; + return success(); + } + + LogicalResult read(Location fileLoc, DialectReader &dialectReader, + OperationName *opName, OperationState &opState) { + uint64_t propertiesIdx; + if (failed(dialectReader.readVarInt(propertiesIdx))) + return failure(); + if (propertiesIdx >= offsetTable.size()) + return dialectReader.emitError("Properties idx out-of-bound for ") + << opName->getStringRef(); + size_t propertiesOffset = offsetTable[propertiesIdx]; + if (propertiesOffset >= propertiesBuffers.size()) + return dialectReader.emitError("Properties offset out-of-bound for ") + << opName->getStringRef(); + + // Acquire the sub-buffer that
represents the requested properties. + ArrayRef<char> rawProperties; + { + // "Seek" to the requested offset by getting a new reader with the right + // sub-buffer. + EncodingReader reader(propertiesBuffers.drop_front(propertiesOffset), + fileLoc); + // Properties are stored as a sequence of {size + raw_data}. + if (failed( + dialectReader.withEncodingReader(reader).readBlob(rawProperties))) + return failure(); + } + // Setup a new reader to read from the `rawProperties` sub-buffer. + EncodingReader reader( + StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); + DialectReader propReader = dialectReader.withEncodingReader(reader); + + auto *iface = opName->getInterface<BytecodeOpInterface>(); + if (iface) { + if (failed(iface->readProperties(propReader, opState))) + return failure(); + return success(); + } + if (opName->isRegistered()) + return propReader.emitError( + "Has properties but missing BytecodeOpInterface for ") + << opName->getStringRef(); + // Unregistered ops store their properties as an attribute. + if (failed(propReader.readAttribute(opState.propertiesAttr))) + return failure(); + return success(); + } + +private: + /// The properties buffer referenced within the bytecode file. + ArrayRef<uint8_t> propertiesBuffers; + + /// Table of offsets into the buffer above, one entry per properties blob. + SmallVector<int64_t> offsetTable; +}; } // namespace LogicalResult @@ -1180,7 +1290,9 @@ lazyLoadableOps.erase(it->getSecond()); lazyLoadableOpsMap.erase(it); auto result = parseRegions(regionStack, regionStack.back()); - assert(regionStack.empty()); + assert((regionStack.empty() || failed(result)) && + "broken invariant: regionStack should be empty when parseRegions " + "succeeds"); return result; } @@ -1384,6 +1496,9 @@ /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; + /// The table of properties referenced by the operation in the bytecode file. + PropertiesSectionReader propertiesReader; + /// The current set of available IR value scopes.
std::vector valueScopes; @@ -1452,7 +1567,7 @@ // Check that all of the required sections were found. for (int i = 0; i < bytecode::Section::kNumSections; ++i) { bytecode::Section::ID sectionID = static_cast(i); - if (!sectionDatas[i] && !isSectionOptional(sectionID)) { + if (!sectionDatas[i] && !isSectionOptional(sectionID, version)) { return reader.emitError("missing data for top-level section: ", ::toString(sectionID)); } @@ -1463,6 +1578,12 @@ fileLoc, *sectionDatas[bytecode::Section::kString]))) return failure(); + // Process the properties section. + if (sectionDatas[bytecode::Section::kProperties] && + failed(propertiesReader.initialize( + fileLoc, *sectionDatas[bytecode::Section::kProperties]))) + return failure(); + // Process the dialect section. if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); @@ -1987,6 +2108,14 @@ opState.attributes = dictAttr; } + if (opMask & bytecode::OpEncodingMask::kHasProperties) { + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); + if (failed( + propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) + return failure(); + } + /// Parse the results of the operation. 
if (opMask & bytecode::OpEncodingMask::kHasResults) { uint64_t numResults; diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -9,11 +9,21 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Bytecode/Encoding.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include +#include +#include +#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -58,6 +68,10 @@ std::min(bytecodeVersion, bytecode::kVersion); } +int64_t BytecodeWriterConfig::getDesiredBytecodeVersion() const { + return impl->bytecodeVersion; +} + //===----------------------------------------------------------------------===// // EncodingEmitter //===----------------------------------------------------------------------===// @@ -318,6 +332,14 @@ void writeAttribute(Attribute attr) override { emitter.emitVarInt(numberingState.getNumber(attr)); } + void writeOptionalAttribute(Attribute attr) override { + if (!attr) { + emitter.emitVarInt(0); + return; + } + emitter.emitVarIntWithFlag(numberingState.getNumber(attr), true); + } + void writeType(Type type) override { emitter.emitVarInt(numberingState.getNumber(type)); } @@ -382,6 +404,102 @@ StringSectionBuilder &stringSection; }; +namespace { +class PropertiesSectionBuilder { +public: + PropertiesSectionBuilder(IRNumberingState &numberingState, + StringSectionBuilder &stringSection, + const BytecodeWriterConfig::Impl &config) + : numberingState(numberingState), stringSection(stringSection), + config(config) 
{} + + /// Emit the op properties in the properties section and return the index of + /// the properties within the section. Return -1 if no properties was emitted. + ssize_t emit(Operation *op) { + EncodingEmitter propertiesEmitter; + if (!op->getPropertiesStorageSize()) + return -1; + if (!op->isRegistered()) { + // Unregistered op are storing properties as an optional attribute. + Attribute prop = *op->getPropertiesStorage().as(); + if (!prop) + return -1; + EncodingEmitter sizeEmitter; + sizeEmitter.emitVarInt(numberingState.getNumber(prop)); + scratch.clear(); + llvm::raw_svector_ostream os(scratch); + sizeEmitter.writeTo(os); + return emit(scratch); + } + + EncodingEmitter emitter; + DialectWriter propertiesWriter(config.bytecodeVersion, emitter, + numberingState, stringSection); + auto iface = cast(op); + iface.writeProperties(propertiesWriter); + scratch.clear(); + llvm::raw_svector_ostream os(scratch); + emitter.writeTo(os); + return emit(scratch); + } + + /// Write the current set of properties to the given emitter. + void write(EncodingEmitter &emitter) { + emitter.emitVarInt(propertiesStorage.size()); + if (propertiesStorage.empty()) + return; + for (const auto &storage : propertiesStorage) { + if (storage.empty()) { + emitter.emitBytes(ArrayRef()); + continue; + } + emitter.emitBytes(ArrayRef(reinterpret_cast(&storage[0]), + storage.size())); + } + } + +private: + // Emit raw data and returns the offset in the internal buffer. + // Data are deduplicated and will be copied in the internal buffer only if + // they don't exist there already. + ssize_t emit(ArrayRef rawProperties) { + // populate a scratch buffer with the properties size. + SmallVector sizeScratch; + { + EncodingEmitter sizeEmitter; + sizeEmitter.emitVarInt(rawProperties.size()); + llvm::raw_svector_ostream os(sizeScratch); + sizeEmitter.writeTo(os); + } + // Append a new storage to the table now. 
+ size_t index = propertiesStorage.size(); + propertiesStorage.emplace_back(); + std::vector &newStorage = propertiesStorage.back(); + size_t propertiesSize = sizeScratch.size() + rawProperties.size(); + newStorage.reserve(propertiesSize); + newStorage.insert(newStorage.end(), sizeScratch.begin(), sizeScratch.end()); + newStorage.insert(newStorage.end(), rawProperties.begin(), + rawProperties.end()); + + // Try to de-duplicate the new serialized properties. + // If the properties is a duplicate, pop it back from the storage. + auto inserted = propertiesUniquing.insert( + std::make_pair(ArrayRef(newStorage), index)); + if (!inserted.second) + propertiesStorage.pop_back(); + return inserted.first->getSecond(); + } + + /// Storage for properties. + std::vector> propertiesStorage; + SmallVector scratch; + DenseMap, int64_t> propertiesUniquing; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; + const BytecodeWriterConfig::Impl &config; +}; +} // namespace + /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need /// to go through an intermediate buffer when interacting with code that wants a /// raw_ostream. @@ -435,11 +553,12 @@ namespace { class BytecodeWriter { public: - BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config) - : numberingState(op), config(config) {} + BytecodeWriter(Operation *op, const BytecodeWriterConfig &config) + : numberingState(op, config), config(config.getImpl()), + propertiesSection(numberingState, stringSection, config.getImpl()) {} /// Write the bytecode for the given root operation. 
- void write(Operation *rootOp, raw_ostream &os); + LogicalResult write(Operation *rootOp, raw_ostream &os); private: //===--------------------------------------------------------------------===// @@ -455,10 +574,10 @@ //===--------------------------------------------------------------------===// // Operations - void writeBlock(EncodingEmitter &emitter, Block *block); - void writeOp(EncodingEmitter &emitter, Operation *op); - void writeRegion(EncodingEmitter &emitter, Region *region); - void writeIRSection(EncodingEmitter &emitter, Operation *op); + LogicalResult writeBlock(EncodingEmitter &emitter, Block *block); + LogicalResult writeOp(EncodingEmitter &emitter, Operation *op); + LogicalResult writeRegion(EncodingEmitter &emitter, Region *region); + LogicalResult writeIRSection(EncodingEmitter &emitter, Operation *op); //===--------------------------------------------------------------------===// // Resources @@ -470,6 +589,11 @@ void writeStringSection(EncodingEmitter &emitter); + //===--------------------------------------------------------------------===// + // Properties + + void writePropertiesSection(EncodingEmitter &emitter); + //===--------------------------------------------------------------------===// // Helpers @@ -487,10 +611,13 @@ /// Configuration dictating bytecode emission. const BytecodeWriterConfig::Impl &config; + + /// Storage for the properties section + PropertiesSectionBuilder propertiesSection; }; } // namespace -void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { +LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { EncodingEmitter emitter; // Emit the bytecode file header. This is how we identify the output as a @@ -510,7 +637,8 @@ writeAttrTypeSection(emitter); // Emit the IR section. - writeIRSection(emitter, rootOp); + if (failed(writeIRSection(emitter, rootOp))) + return failure(); // Emit the resources section. writeResourceSection(rootOp, emitter); @@ -518,8 +646,13 @@ // Emit the string section. 
writeStringSection(emitter); + // Emit the properties section. + writePropertiesSection(emitter); + // Write the generated bytecode to the provided output stream. emitter.writeTo(os); + + return success(); } //===----------------------------------------------------------------------===// @@ -656,7 +789,8 @@ //===----------------------------------------------------------------------===// // Operations -void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) { +LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter, + Block *block) { ArrayRef args = block->getArguments(); bool hasArgs = !args.empty(); @@ -685,10 +819,12 @@ // Emit the operations within the block. for (Operation &op : *block) - writeOp(emitter, &op); + if (failed(writeOp(emitter, &op))) + return failure(); + return success(); } -void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { emitter.emitVarInt(numberingState.getNumber(op->getName())); // Emit a mask for the operation components. We need to fill this in later @@ -702,10 +838,27 @@ emitter.emitVarInt(numberingState.getNumber(op->getLoc())); // Emit the attributes of this operation. - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); + // Temporarily allow deployment to version <4 ; this won't be possible in the + // future when we remove getAttrDictionary(). + if (config.bytecodeVersion < 4) + attrs = op->getAttrDictionary(); if (!attrs.empty()) { opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; - emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary())); + emitter.emitVarInt(numberingState.getNumber(attrs)); + } + + // Emit the properties of this operation, for now we still support deployment + // to version <4. 
+ if (config.bytecodeVersion >= 4) { + ssize_t propertiesId = propertiesSection.emit(op); + if (propertiesId != -1) { + if (config.bytecodeVersion < 4) + return op->emitOpError( + "cannot emit properties before bytecode version 4"); + opEncodingMask |= bytecode::OpEncodingMask::kHasProperties; + emitter.emitVarInt(propertiesId); + } } // Emit the result types of the operation. @@ -757,15 +910,18 @@ // If the region is not isolated from above, or we are emitting bytecode // targetting version <2, we don't use a section. if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { - writeRegion(emitter, ®ion); + if (failed(writeRegion(emitter, ®ion))) + return failure(); continue; } EncodingEmitter regionEmitter; - writeRegion(regionEmitter, ®ion); + if (failed(writeRegion(regionEmitter, ®ion))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); } } + return success(); } void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, @@ -856,11 +1012,14 @@ } } -void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { +LogicalResult BytecodeWriter::writeRegion(EncodingEmitter &emitter, + Region *region) { // If the region is empty, we only need to emit the number of blocks (which is // zero). - if (region->empty()) - return emitter.emitVarInt(/*numBlocks*/ 0); + if (region->empty()) { + emitter.emitVarInt(/*numBlocks*/ 0); + return success(); + } // Emit the number of blocks and values within the region. unsigned numBlocks, numValues; @@ -870,10 +1029,13 @@ // Emit the blocks within the region. for (Block &block : *region) - writeBlock(emitter, &block); + if (failed(writeBlock(emitter, &block))) + return failure(); + return success(); } -void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) { +LogicalResult BytecodeWriter::writeIRSection(EncodingEmitter &emitter, + Operation *op) { EncodingEmitter irEmitter; // Write the IR section the same way as a block with no arguments. 
Note that @@ -882,9 +1044,11 @@ irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false); // Emit the operations. - writeOp(irEmitter, op); + if (failed(writeOp(irEmitter, op))) + return failure(); emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); + return success(); } //===----------------------------------------------------------------------===// @@ -1000,14 +1164,22 @@ emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter)); } +//===----------------------------------------------------------------------===// +// Properties + +void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) { + EncodingEmitter propertiesEmitter; + propertiesSection.write(propertiesEmitter); + emitter.emitSection(bytecode::Section::kProperties, + std::move(propertiesEmitter)); +} + //===----------------------------------------------------------------------===// // Entry Points //===----------------------------------------------------------------------===// LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config) { - BytecodeWriter writer(op, config.getImpl()); - writer.write(op, os); - // Currently there is no failure case. 
- return success(); + BytecodeWriter writer(op, config); + return writer.write(op, os); } diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt --- a/mlir/lib/Bytecode/Writer/CMakeLists.txt +++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt @@ -8,4 +8,5 @@ LINK_LIBS PUBLIC MLIRIR MLIRSupport + MLIRBytecodeOpInterface ) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -18,6 +18,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/CodeGen/NonRelocatableStringpool.h" +#include namespace mlir { class BytecodeDialectInterface; @@ -133,7 +135,7 @@ /// emission. class IRNumberingState { public: - IRNumberingState(Operation *op); + IRNumberingState(Operation *op, const BytecodeWriterConfig &config); /// Return the numbered dialects. auto getDialects() { @@ -241,6 +243,9 @@ /// The next value ID to assign when numbering. unsigned nextValueID = 0; + + // Configuration: useful to query the required version to emit. 
+ const BytecodeWriterConfig &config; }; } // namespace detail } // namespace bytecode diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -8,6 +8,7 @@ #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -24,6 +25,10 @@ NumberingDialectWriter(IRNumberingState &state) : state(state) {} void writeAttribute(Attribute attr) override { state.number(attr); } + void writeOptionalAttribute(Attribute attr) override { + if (attr) + state.number(attr); + } void writeType(Type type) override { state.number(type); } void writeResourceHandle(const AsmDialectResourceHandle &resource) override { state.number(resource.getDialect(), resource); @@ -106,7 +111,9 @@ value->number = idx; } -IRNumberingState::IRNumberingState(Operation *op) { +IRNumberingState::IRNumberingState(Operation *op, + const BytecodeWriterConfig &config) + : config(config) { // Compute a global operation ID numbering according to the pre-order walk of // the IR. This is used as reference to construct use-list orders. unsigned operationID = 0; @@ -276,10 +283,28 @@ } // Only number the operation's dictionary if it isn't empty. - DictionaryAttr dictAttr = op.getAttrDictionary(); + DictionaryAttr dictAttr = op.getDiscardableAttrDictionary(); + if (config.getDesiredBytecodeVersion() < 4) + dictAttr = op.getAttrDictionary(); if (!dictAttr.empty()) number(dictAttr); + // Visit the operation properties (if any) to make sure referenced attributes + // are numbered. 
+ if (config.getDesiredBytecodeVersion() >= 4 && + op.getPropertiesStorageSize()) { + if (op.isRegistered()) { + auto iface = cast(op); + NumberingDialectWriter writer(*this); + iface.writeProperties(writer); + } else { + // Unregistered op are storing properties as an optional attribute. + Attribute prop = *op.getPropertiesStorage().as(); + if (prop) + number(prop); + } + } + number(op.getLoc()); } diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 3 +// VERSION: bytecode version 127 is newer than the current version {{[0-9]+}} //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.h @@ -14,6 +14,7 @@ #ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H #define MLIR_TESTTRANSFORMDIALECTEXTENSION_H +#include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1103,6 +1103,21 @@ "getDiag")) ->body(); + auto &readPropertiesMethod = + opClass + .addStaticMethod( + "::mlir::LogicalResult", 
"readProperties", + MethodParameter("::mlir::DialectBytecodeReader &", "reader"), + MethodParameter("::mlir::OperationState &", "state")) + ->body(); + + auto &writePropertiesMethod = + opClass + .addMethod( + "void", "writeProperties", + MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) + ->body(); + opClass.declare("Properties", "FoldAdaptor::Properties"); // Convert the property to the attribute form. @@ -1304,6 +1319,40 @@ } } verifyInherentAttrsMethod << " return ::mlir::success();"; + + // Populate bytecode serialization logic. + readPropertiesMethod + << " auto &prop = state.getOrAddProperties(); (void)prop;"; + writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n"; + for (const auto &attrOrProp : attrOrProperties) { + if (const auto *namedProperty = + attrOrProp.dyn_cast()) { + // TODO + continue; + } + const auto *namedAttr = attrOrProp.dyn_cast(); + StringRef name = namedAttr->attrName; + if (namedAttr->isRequired) { + readPropertiesMethod << formatv(R"( + if (failed(reader.readAttribute(prop.{0}))) + return failure(); +)", + name); + writePropertiesMethod + << formatv(" writer.writeAttribute(prop.{0});\n", name); + } else { + readPropertiesMethod << formatv(R"( + if (failed(reader.readOptionalAttribute(prop.{0}))) + return failure(); +)", + name); + writePropertiesMethod << formatv(R"( + writer.writeOptionalAttribute(prop.{0}); +)", + name); + } + } + readPropertiesMethod << " return success();"; } void OpEmitter::genAttrGetters() { @@ -3300,6 +3349,9 @@ // native/interface traits and after all the traits with `StructuralOpTrait`. opClass.addTrait("::mlir::OpTrait::OpInvariants"); + if (emitHelper.hasProperties()) + opClass.addTrait("::mlir::BytecodeOpInterface::Trait"); + // Add the native and interface traits. for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) {