diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, only
+/// power-of-two integer types are handled, and they are converted
+/// to wider integers that are at least 8 bits wide.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+  explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+  unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+  unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 
 /// Create a pass to bufferize Arith ops.
 std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,12 @@
 void populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Adds patterns to emulate arith and func ops over narrow types with ops
+/// over wider supported types. Users need to add conversions for the
+/// computation domain of the narrow types.
+void populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -25,6 +25,7 @@
 
 namespace arith {
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 } // namespace arith
 
 namespace memref {
@@ -73,6 +74,17 @@
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);
 
+/// Appends patterns for emulating memref operations over narrow types with
+/// ops over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
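+/// For example, with a load/store bitwidth of 8, `memref<4xi4>` is converted
+/// to `memref<4xi8>` (same shape, widened element type).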
+void populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExpandOps.cpp
   IntNarrowing.cpp
   IntRangeOptimizations.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,61 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+    unsigned targetBitwidth)
+    : loadStoreBitwidth(targetBitwidth) {
+  assert(llvm::isPowerOf2_32(targetBitwidth) &&
+         "Only power-of-two integers are supported");
+
+  // Allow unknown types.
+  addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> std::optional<Type> {
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return std::nullopt;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return std::nullopt;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
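+  // These rewrite `func.func` signatures, `func.call` operands/results, and
+  // `func.return` operands to use the converted types.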
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,247 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// When data is loaded/stored in `targetBits` granularity, but is used in
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), a `targetBits`
+/// value is treated as an array of elements of width `sourceBits`.
+/// Returns the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` is 4 and `targetBits` is 8, the x-th element is located at
+/// (x % 2) * 4, since one i8 holds two i4 elements of 4 bits each.
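+/// For instance, the element at srcIdx = 5 lives in the third byte, at bit
+/// offset (5 % 2) * 4 = 4.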
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                       op.getMemRefType()));
+    }
+
+    if (op.getMemRefType() == newTy)
+      return failure();
+
+    auto loc = op.getLoc();
+    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+    unsigned sourceRank = sourceType.getRank();
+
+    auto srcElementType = sourceType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+
+    // The emulation only works on 1D memref types.
+    // To make this work on N-D memrefs, we need to linearize the offset.
+    //
+    // For example, to emulate i4 with i8, the following op:
+    //
+    // %0 = memref.load %arg0[%v0, %v1] :
+    //                  memref<?x?xi4, strided<[?, ?], offset: ?>>
+    //
+    // can be replaced with
+    //
+    // %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+    //
+    // %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+    // %scaled_linear_offset = %linearized_offset / 8 * 4
+    // %linearized_size = %size0 * %size1
+    //
+    // %linearized = memref.reinterpret_cast %b, offset = [%offset],
+    //               sizes = [%linearized_size], strides = [%stride#1]
+    //
+    // %new_load = memref.load %linearized[%scaled_linear_offset] :
+    //             memref<?xi8, strided<[?], offset: ?>>
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, adaptor.getMemref());
+    auto baseBuffer = stridedMetadata.getBaseBuffer();
+    auto baseSizes = stridedMetadata.getSizes();
+    auto baseStrides = stridedMetadata.getStrides();
+    auto baseOffset = stridedMetadata.getOffset();
+
+    SmallVector<Value> indices = adaptor.getIndices();
+    assert(indices.size() == baseStrides.size());
+    assert(indices.size() == sourceRank);
+
+    IndexType targetType = rewriter.getIndexType();
+    IntegerAttr attr = rewriter.getIndexAttr(dstBits / srcBits);
+    auto scaler = rewriter.create<arith::ConstantOp>(loc, targetType, attr);
+
+    // Linearize offset and sizes.
+    SmallVector<Value> adjustOffsets;
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      adjustOffsets.push_back(
+          rewriter.create<arith::MulIOp>(loc, indices[i], baseStrides[i]));
+    }
+
+    Value linearizedOffset = adjustOffsets[0];
+    Value linearizedSize = baseSizes[0];
+    if (sourceRank == 1) {
+      linearizedOffset =
+          rewriter.create<arith::DivUIOp>(loc, linearizedOffset, scaler);
+    }
+
+    for (unsigned i = 1; i < sourceRank; ++i) {
+      linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset,
+                                                        adjustOffsets[i]);
+      linearizedOffset =
+          rewriter.create<arith::DivUIOp>(loc, linearizedOffset, scaler);
+      linearizedSize =
+          rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
+    }
+
+    // Flatten the n-D memref to a 1-D memref.
+    auto layoutAttr = StridedLayoutAttr::get(
+        sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+    int64_t staticShape = sourceType.hasStaticShape()
+                              ? sourceType.getNumElements()
+                              : ShapedType::kDynamic;
+    auto flattenMemrefType = MemRefType::get(
+        staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+    auto reinterpret = rewriter.create<memref::ReinterpretCastOp>(
+        loc, flattenMemrefType, baseBuffer, baseOffset, linearizedSize,
+        baseStrides.back());
+
+    auto newLoad = rewriter.create<memref::LoadOp>(
+        loc, srcElementType, reinterpret.getResult(), linearizedOffset,
+        op.getNontemporal());
+
+    // Get the offset and shift the bits to the rightmost.
+    // Note that currently only big-endian layout is supported.
+    auto lastIdx = rewriter.create<arith::IndexCastUIOp>(
+        loc, srcElementType, adaptor.getIndices().back());
+    Value bitwidthOffset =
+        getOffsetForBitwidth(loc, lastIdx, srcBits, dstBits, rewriter);
+    auto bitsLoad =
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
+
+    // Get the corresponding bits. If the arith computation bitwidth equals
+    // the emulated bitwidth, we apply a mask to extract the low bits.
+    // Otherwise, we truncate the result.
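+    // For example, when emulating i4 with i8: if the arith compute bitwidth
+    // is 8, the emulated type converts to i8 and the load is masked with 0xF;
+    // if it is 4, the loaded i8 is truncated back to i4 instead.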
+    Operation *result;
+    auto resultTy = getTypeConverter()->convertType(oldElementType);
+    if (resultTy == srcElementType) {
+      auto mask = rewriter.create<arith::ConstantOp>(
+          loc, srcElementType,
+          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+    } else {
+      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+    }
+
+    rewriter.replaceOp(op, result->getResult(0));
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+
+  // Populate `memref.*` conversion patterns.
+  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad>(typeConverter,
+                                                      patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy)
+          return ty;
+
+        unsigned width = intTy.getWidth();
+        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+        if (width >= loadStoreWidth)
+          return ty;
+
+        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+                                          intTy.getSignedness());
+        if (!newElemTy)
+          return std::nullopt;
+
+        return ty.cloneWith(std::nullopt, newElemTy);
+      });
+}
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT: return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+  return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: return [[ARG]] : vector<2xi32> +func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> { + return %a : vector<2xi32> +} + +// CHECK-LABEL: func @identity_scalar +// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8 +// CHECK-NEXT: return [[ARG]] : i8 +func.func @identity_scalar(%x : i4) -> i4 { + return %x : i4 +} + +// CHECK-LABEL: func @identity_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: return [[ARG]] : vector<4xi8> +func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> { + return %x : vector<4xi4> +} + +// CHECK-LABEL: func @identity_vector2d +// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8> +// CHECK-NEXT: return [[ARG]] : vector<3x4xi8> +func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> { + return %x : vector<3x4xi4> +} + +// CHECK-LABEL: func @call +// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: return [[RES]] : vector<4xi8> +func.func @call(%a : vector<4xi4>) -> vector<4xi4> { + %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4> + return %res : vector<4xi4> +} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s + +// Expect no conversions, i32 is supported. +// CHECK-LABEL: func @memref_i32 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: return +func.func @memref_i32() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i32 + %m = memref.alloc() : memref<4xi32, 1> + %v = memref.load %m[%c0] : memref<4xi32, 1> + memref.store %c1, %m[%c0] : memref<4xi32, 1> + return +} + +// Expect no conversions, f32 is not an integer type. 
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1.0 : f32
+  %m = memref.alloc() : memref<4xf32, 1>
+  %v = memref.load %m[%c0] : memref<4xf32, 1>
+  memref.store %c1, %m[%c0] : memref<4xf32, 1>
+  return
+}
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[CI:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[LINEAR:.*]] = arith.muli %[[ARG]], %[[STRIDES]] : index
+// CHECK-NEXT: %[[INDEX:.*]] = arith.divui %[[LINEAR]], %[[CI]] : index
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+  %0 = memref.alloc() : memref<4xi4>
+  %1 = memref.load %0[%arg0] : memref<4xi4>
+  return
+}
+
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4x4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[M]] : memref<4x4xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[CI:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[IDX1:.*]] = arith.muli %[[ARG0]], %[[STRIDES]]#0 : index
+// CHECK-NEXT: %[[IDX2:.*]] = arith.muli %[[ARG1]], %[[STRIDES]]#1 : index
+// CHECK-NEXT: %[[LINEAR:.*]] = arith.addi %[[IDX1]], %[[IDX2]] : index
+// CHECK-NEXT: %[[INDEX:.*]] = arith.divui %[[LINEAR]], %[[CI]] : index
+// CHECK-NEXT: %[[LSIZE:.*]] = arith.muli %[[SIZES]]#0, %[[SIZES]]#1 : index
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<16xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<16xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) {
+  %0 = memref.alloc() : memref<4x4xi4>
+  %1 = memref.load %0[%arg0, %arg1] : memref<4x4xi4>
+  return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestEmulateNarrowType.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -0,0 +1,115 @@
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+    : public PassWrapper<TestEmulateNarrowTypePass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+  TestEmulateNarrowTypePass() = default;
+  TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Narrow Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+        loadStoreEmulateBitwidth < 8) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::NarrowTypeEmulationConverter typeConverter(
+        loadStoreEmulateBitwidth);
+
+    // Convert scalar integer types.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      unsigned width = ty.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+    });
+
+    // Convert vector types.
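+    // E.g., with arith-compute-bitwidth=8, vector<3x4xi4> becomes
+    // vector<3x4xi8>, while vector<2xi32> stays unchanged.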
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy)
+        return ty;
+
+      unsigned width = intTy.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return VectorType::get(
+          to_vector(ty.getShape()),
+          IntegerType::get(ty.getContext(), arithComputeBitwidth));
+    });
+
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&typeConverter](Operation *op) {
+          return typeConverter.isLegal(
+              cast<func::FuncOp>(op).getFunctionType());
+        });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect>(
+        opLegalCallback);
+
+    RewritePatternSet patterns(ctx);
+
+    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  Option<unsigned> loadStoreEmulateBitwidth{
+      *this, "memref-load-bitwidth",
+      llvm::cl::desc("memref load/store emulation bit width"),
+      llvm::cl::init(8)};
+
+  Option<unsigned> arithComputeBitwidth{
+      *this, "arith-compute-bitwidth",
+      llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+  PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@
 void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
@@ -204,6 +205,7 @@
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
+  mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();