diff --git a/mlir/include/mlir/Support/EndianUtilities.h b/mlir/include/mlir/Support/EndianUtilities.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Support/EndianUtilities.h
@@ -0,0 +1,29 @@
+//===- EndianUtilities.h - utilities for endian conversion ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Common utilities for endian conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_ENDIANUTILITIES_H_
+#define MLIR_SUPPORT_ENDIANUTILITIES_H_
+
+#include "mlir/IR/StandardTypes.h"
+
+namespace mlir {
+/// Convert the endianness of dense-element raw data for big-endian (BE)
+/// machines. Each 16-, 32- or 64-bit element of `inRawData` is read as
+/// little-endian (LE) and stored to `outRawData` in host byte order. A byte
+/// swap is its own inverse, so the same call maps LE -> BE and BE -> LE on a
+/// BE machine; on an LE machine the elements are copied unchanged.
+/// `outRawData` is written through, so it must refer to writable storage that
+/// is as large as `inRawData`.
+void convEndianBE(ArrayRef<char> inRawData, ArrayRef<char> outRawData,
+                  ShapedType type);
+} // namespace mlir
+#endif // MLIR_SUPPORT_ENDIANUTILITIES_H_
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/EndianUtilities.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -1505,7 +1506,14 @@
   auto numElements = type.getNumElements();
   if (!attr.isSplat() && allowHex &&
       shouldPrintElementsAttrWithHex(numElements)) {
-    ArrayRef<char> rawData = attr.getRawData();
+    ArrayRef<char> inRawData = attr.getRawData();
+    // Convert endianness on big-endian (BE) machines. `inRawData` is BE on BE
+    // machines, but the hex form is always printed in little-endian (LE)
+    // format, so it is converted here. Not converted on LE machines.
+    // convEndianBE fills the buffer in place, so it must be sized up front.
+    SmallVector<char, 64> outDataVec(inRawData.size());
+    ArrayRef<char> rawData(outDataVec.data(), outDataVec.size());
+    convEndianBE(inRawData, rawData, type);
     os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
        << "\"";
     return;
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -15,7 +15,9 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/EndianUtilities.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Endian.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -694,7 +696,16 @@
     if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
       return nullptr;
 
-    ArrayRef<char> rawData(data.data(), data.size());
+    // Convert endianness on big-endian (BE) machines. `inRawData` is
+    // little-endian (LE) because the hex raw data of a dense elements attribute
+    // is always in LE format. It is converted into BE here for use on BE
+    // machines. Not converted on LE machines.
+    ArrayRef<char> inRawData(data.data(), data.size());
+    // convEndianBE fills the buffer in place, so it must be sized up front.
+    SmallVector<char, 64> outDataVec(inRawData.size());
+    ArrayRef<char> rawData(outDataVec.data(), outDataVec.size());
+    convEndianBE(inRawData, rawData, type);
+
     bool detectedSplat = false;
     if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
       p.emitError(loc) << "elements hex data size is invalid for provided type: "
diff --git a/mlir/lib/Support/CMakeLists.txt b/mlir/lib/Support/CMakeLists.txt
--- a/mlir/lib/Support/CMakeLists.txt
+++ b/mlir/lib/Support/CMakeLists.txt
@@ -4,12 +4,14 @@
   MlirOptMain.cpp
   StorageUniquer.cpp
   ToolUtilities.cpp
+  EndianUtilities.cpp
 )
 
 add_mlir_library(MLIRSupport
   FileUtilities.cpp
   StorageUniquer.cpp
   ToolUtilities.cpp
+  EndianUtilities.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Support
diff --git a/mlir/lib/Support/EndianUtilities.cpp b/mlir/lib/Support/EndianUtilities.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Support/EndianUtilities.cpp
@@ -0,0 +1,76 @@
+//===- EndianUtilities.cpp - utilities for endian conversion---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definitions of common utilities for endian conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/EndianUtilities.h"
+#include "llvm/Support/Endian.h"
+
+using namespace mlir;
+using llvm::support::ulittle16_t;
+using llvm::support::ulittle32_t;
+using llvm::support::ulittle64_t;
+
+/// Return the bit width which DenseElementsAttr should use for this type.
+/// This is the same as `getDenseElementBitWidth` in
+/// `mlir/lib/IR/AttributeDetail.h`.
+static inline size_t getDenseElementBitWidth(Type eltType) {
+  // Align the width for complex to 8 to make storage and interpretation easier.
+  if (ComplexType comp = eltType.dyn_cast<ComplexType>())
+    return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
+  if (eltType.isIndex())
+    return IndexType::kInternalStorageBitWidth;
+  return eltType.getIntOrFloatBitWidth();
+}
+
+/// Get the bitwidth of a dense element type within the buffer.
+/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
+/// These are the same as `getDenseElementStorageWidth` in
+/// `mlir/lib/IR/Attributes.cpp`.
+static size_t getDenseElementStorageWidth(size_t origWidth) {
+  return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+}
+static size_t getDenseElementStorageWidth(Type elementType) {
+  return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
+}
+
+/// Copy `count` little-endian `LittleT` elements from `in` to `out` in host
+/// byte order. Both buffers are accessed bytewise, so alignment is irrelevant.
+template <typename LittleT, typename NativeT>
+static void copyLEToHost(const char *in, char *out, size_t count) {
+  const auto *src = reinterpret_cast<const LittleT *>(in);
+  for (size_t i = 0; i != count; ++i) {
+    NativeT value = src[i];
+    std::copy_n(reinterpret_cast<const char *>(&value), sizeof(NativeT),
+                out + i * sizeof(NativeT));
+  }
+}
+
+/// Convert each element of `inRawData` between little-endian and host byte
+/// order into `outRawData`, which must view writable storage. The element
+/// count is taken from the buffers, so splat or short payloads stay in bounds.
+void mlir::convEndianBE(ArrayRef<char> inRawData, ArrayRef<char> outRawData,
+                        ShapedType type) {
+  Type elementType = type.getElementType();
+  // A complex element is laid out as two consecutive scalars.
+  if (ComplexType complexTy = elementType.dyn_cast<ComplexType>())
+    elementType = complexTy.getElementType();
+  size_t storageBitWidth = getDenseElementStorageWidth(elementType);
+  size_t numBytes = std::min(inRawData.size(), outRawData.size());
+  char *out = const_cast<char *>(outRawData.data());
+  if (storageBitWidth == 16)
+    copyLEToHost<ulittle16_t, uint16_t>(inRawData.data(), out, numBytes / 2);
+  else if (storageBitWidth == 32)
+    copyLEToHost<ulittle32_t, uint32_t>(inRawData.data(), out, numBytes / 4);
+  else if (storageBitWidth == 64)
+    copyLEToHost<ulittle64_t, uint64_t>(inRawData.data(), out, numBytes / 8);
+  else // 1- and 8-bit data has no byte order; other widths are left as-is.
+    std::copy_n(inRawData.data(), numBytes, out);
+}
diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -1,12 +1,24 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file -mlir-print-elementsattrs-with-hex-if-larger=1 | FileCheck %s --check-prefix=HEX
 // RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s
 
+// HEX: dense<"0x000020410000A040"> : tensor<2xf32>
+"foo.op"() {dense.attr = dense<[10.0, 5.0]> : tensor<2xf32>} : () -> ()
+
 // HEX: dense<"0x00000000000024400000000000001440"> : tensor<2xf64>
 "foo.op"() {dense.attr = dense<[10.0, 5.0]> : tensor<2xf64>} : () -> ()
 
+// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf32>
+"foo.op"() {dense.attr = dense<"0x000020410000A040"> : tensor<2xf32>} : () -> ()
+
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
 "foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
 
+// CHECK: dense<[1, 2]> : tensor<2xi8>
+"foo.op"() {dense.attr = dense<"0x0102"> : tensor<2xi8>} : () -> ()
+
+// CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex<f32>>
+"foo.op"() {dense.attr = dense<"0x000020410000A040000020410000A040"> : tensor<2xcomplex<f32>>} : () -> ()
+
 // CHECK: dense<(1.000000e+01,5.000000e+00)> : tensor<2xcomplex<f64>>
 "foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> ()
 