diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
@@ -0,0 +1,29 @@
+//===- NarrowIntEmulationConverter.h - Type Converter for NIE ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_INT_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_INT_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts integer types that are too narrow and not supported by the target
+/// hardware. Currently, we only handle power-of-two integer types and convert
+/// them to wider integers that are equal or larger than 8 bits.
+class NarrowIntEmulationConverter : public TypeConverter {
+public:
+  /// `targetWideInt` is stored in `targetBitwidth`; it must be a power of two
+  /// and at least 8.
+  explicit NarrowIntEmulationConverter(unsigned targetWideInt);
+
+  /// Bit width of the wide integers used to emulate narrow integer types.
+  unsigned targetBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_INT_EMULATION_CONVERTER_H_
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
+class NarrowIntEmulationConverter;
 
 /// Create a pass to bufferize Arith ops.
 std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,11 @@
 void populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Adds patterns that rewrite `func` ops (function signatures, calls and
+/// returns) to use the integer types produced by `typeConverter`.
+void populateArithNarrowIntEmulationPatterns(
+    NarrowIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -25,6 +25,7 @@
 
 namespace arith {
 class WideIntEmulationConverter;
+class NarrowIntEmulationConverter;
 } // namespace arith
 
 namespace memref {
@@ -73,6 +74,19 @@
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);
 
+/// Appends patterns for emulating narrow integer memref operations with ops
+/// over wider integer types. Currently `memref.alloc` and `memref.load` are
+/// supported.
+void populateMemRefNarrowIntEmulationPatterns(
+    arith::NarrowIntEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating narrow integer memref operations
+/// with ops over wider integer types: integer memref element types narrower
+/// than `typeConverter.targetBitwidth` are widened to that bit width.
+void populateMemRefNarrowIntEmulationConversions(
+    arith::NarrowIntEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateWideInt.cpp
+  EmulateNarrowInt.cpp
   ExpandOps.cpp
   IntNarrowing.cpp
   IntRangeOptimizations.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
@@ -0,0 +1,65 @@
+//===- EmulateNarrowInt.cpp - Narrow integer operation emulation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+/// Registers an identity fallback for all types and a conversion for function
+/// types that converts each of their input and result types.
+arith::NarrowIntEmulationConverter::NarrowIntEmulationConverter(
+    unsigned targetWideInt)
+    : targetBitwidth(targetWideInt) {
+  assert(llvm::isPowerOf2_32(targetWideInt) && targetWideInt >= 8 &&
+         "Only power-of-two integers of at least 8 bits are supported");
+
+  // Allow unknown types.
+  addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> std::optional<Type> {
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return std::nullopt;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return std::nullopt;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+/// Adds the `func.func`, `func.call` and `func.return` type conversion
+/// patterns driven by `typeConverter`.
+void arith::populateArithNarrowIntEmulationPatterns(
+    NarrowIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  EmulateNarrowInt.cpp
   ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
@@ -0,0 +1,222 @@
+//===- EmulateNarrowInt.cpp - Narrow integer operation emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Returns the bit offset, (srcIdx % (targetBits / sourceBits)) * sourceBits,
+/// of element `srcIdx` of a 1-D array of `sourceBits`-wide elements inside the
+/// `targetBits`-wide container holding it. `srcIdx` is assumed non-negative
+/// and must be an integer of `targetBits` bits. E.g. with i4 elements packed
+/// in i8 containers, the x-th element starts at bit (x % 2) * 4.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+/// Emulates a narrow-integer load by loading and shifting its wide container.
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert result type: {0}", op.getType()));
+
+    auto loc = op.getLoc();
+    Value source = adaptor.getMemref();
+    auto sourceType = cast<MemRefType>(source.getType());
+    auto srcElementType = sourceType.getElementType();
+    unsigned sourceRank = sourceType.getRank();
+    // Success must imply a rewrite, so bail out unless the memref was widened.
+    if (sourceType == op.getMemRefType() || sourceRank == 0)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "expected a widened memref of rank >= 1");
+
+    auto oldElementType = op.getMemRefType().getElementType();
+    auto resultType = dyn_cast<IntegerType>(newTy);
+    if (!resultType || !isa<IntegerType>(oldElementType) ||
+        !isa<IntegerType>(srcElementType))
+      return rewriter.notifyMatchFailure(op->getLoc(), "expected integers");
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    int resultBits = resultType.getWidth();
+    if (dstBits <= srcBits || dstBits % srcBits != 0 || resultBits < srcBits)
+      return rewriter.notifyMatchFailure(op->getLoc(), "unsupported widths");
+
+    // The emulation works on a 1-D view, so N-D accesses are linearized:
+    //   %b, %off, %sizes:N, %strides:N = memref.extract_strided_metadata %src
+    //   %lin_off  = sum_i(%idx_i * %strides#i), last index / (dstBits/srcBits)
+    //   %lin_size = prod_i(%sizes#i)
+    //   %flat = memref.reinterpret_cast %b to offset: [%off],
+    //           sizes: [%lin_size], strides: [%strides#(N-1)]
+    //   %wide = memref.load %flat[%lin_off]
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+    auto baseBuffer = stridedMetadata.getBaseBuffer();
+    auto baseSizes = stridedMetadata.getSizes();
+    auto baseStrides = stridedMetadata.getStrides();
+    auto baseOffset = stridedMetadata.getOffset();
+
+    SmallVector<Value> indices = adaptor.getIndices();
+    assert(indices.size() == baseStrides.size());
+    assert(indices.size() == sourceRank);
+
+    // Only the last index is modified to load the bits needed.
+    IndexType targetType = rewriter.getIndexType();
+    IntegerAttr attr = rewriter.getIndexAttr(dstBits / srcBits);
+    auto scaler = rewriter.create<arith::ConstantOp>(loc, targetType, attr);
+    indices.back() =
+        rewriter.create<arith::DivUIOp>(loc, indices.back(), scaler);
+
+    SmallVector<Value> adjustOffsets;
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      adjustOffsets.push_back(
+          rewriter.create<arith::MulIOp>(loc, indices[i], baseStrides[i]));
+    }
+
+    // Linearize offset and sizes.
+    Value linearizedOffset = adjustOffsets[0];
+    Value linearizedSize = baseSizes[0];
+    for (unsigned i = 1; i < sourceRank; ++i) {
+      linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset,
+                                                        adjustOffsets[i]);
+      linearizedSize =
+          rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
+    }
+
+    // Flatten n-D MemRef to 1-D MemRef.
+    auto layoutAttr = StridedLayoutAttr::get(
+        sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+    int64_t staticShape = sourceType.hasStaticShape()
+                              ? sourceType.getNumElements()
+                              : ShapedType::kDynamic;
+    auto flattenMemrefType = MemRefType::get(
+        staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+    auto reinterpret = rewriter.create<memref::ReinterpretCastOp>(
+        loc, flattenMemrefType, baseBuffer, baseOffset, linearizedSize,
+        baseStrides.back());
+
+    auto newLoad = rewriter.create<memref::LoadOp>(
+        loc, srcElementType, reinterpret.getResult(), linearizedOffset,
+        op.getNontemporal());
+
+    // Get the offset and shift the bits to the rightmost.
+    auto lastIdx = rewriter.create<arith::IndexCastUIOp>(
+        loc, srcElementType, adaptor.getIndices().back());
+    Value bitwidthOffset =
+        getOffsetForBitwidth(loc, lastIdx, srcBits, dstBits, rewriter);
+    auto bitsLoad =
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
+
+    // Return the element (now in the low `srcBits` bits) as the result type.
+    // Unless we truncate to exactly `srcBits`, mask off neighbouring elements.
+    Value result = bitsLoad;
+    if (resultBits != srcBits) {
+      auto maskAttr = rewriter.getIntegerAttr(
+          srcElementType, APInt::getLowBitsSet(dstBits, srcBits));
+      Value mask =
+          rewriter.create<arith::ConstantOp>(loc, srcElementType, maskAttr);
+      result = rewriter.create<arith::AndIOp>(loc, result, mask);
+    }
+    if (resultBits < dstBits)
+      result = rewriter.create<arith::TruncIOp>(loc, resultType, result);
+    else if (resultBits > dstBits)
+      result = rewriter.create<arith::ExtUIOp>(loc, resultType, result);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowIntEmulationPatterns(
+    arith::NarrowIntEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  // Populate `memref.*` conversion patterns.
+  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad>(typeConverter,
+                                                       patterns.getContext());
+}
+
+void memref::populateMemRefNarrowIntEmulationConversions(
+    arith::NarrowIntEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy || intTy.getWidth() >= typeConverter.targetBitwidth)
+          return ty;
+        // Widen only the element type; shape, layout and memory space stay.
+        Type newElemTy =
+            IntegerType::get(ty.getContext(), typeConverter.targetBitwidth,
+                             intTy.getSignedness());
+        return ty.cloneWith(std::nullopt, newElemTy);
+      });
+}
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-int.mlir b/mlir/test/Dialect/Arith/emulate-narrow-int.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-narrow-int.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-int-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT: return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+  return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+  return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT: return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+  return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+  return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+  return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+  %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+  return %res : vector<4xi4>
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-int-bitwidth=4 memref-target-bits=8" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : i32
+  %m = memref.alloc() : memref<4xi32, 1>
+  %v = memref.load %m[%c0] : memref<4xi32, 1>
+  memref.store %c1, %m[%c0] : memref<4xi32, 1>
+  return
+}
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1.0 : f32
+  %m = memref.alloc() : memref<4xf32, 1>
+  %v = memref.load %m[%c0] : memref<4xf32, 1>
+  memref.store %c1, %m[%c0] : memref<4xf32, 1>
+  return
+}
+
+// i4 elements are loaded from their i8 container and shifted into place.
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[CI:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[SCALE:.*]] = arith.divui %[[ARG]], %[[CI]] : index
+// CHECK-NEXT: %[[INDEX:.*]] = arith.muli %[[SCALE]], %[[STRIDES]] : index
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+  %0 = memref.alloc() : memref<4xi4>
+  %1 = memref.load %0[%arg0] : memref<4xi4>
+  return
+}
+
+// N-D loads are linearized using the strided metadata of the memref.
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4x4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[M]] : memref<4x4xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[CI:.*]] = arith.constant 2 : index
+// CHECK-NEXT: %[[SCALE:.*]] = arith.divui %[[ARG1]], %[[CI]] : index
+// CHECK-NEXT: %[[IDX1:.*]] = arith.muli %[[ARG0]], %[[STRIDES]]#0 : index
+// CHECK-NEXT: %[[IDX2:.*]] = arith.muli %[[SCALE]], %[[STRIDES]]#1 : index
+// CHECK-NEXT: %[[INDEX:.*]] = arith.addi %[[IDX1]], %[[IDX2]] : index
+// CHECK-NEXT: %[[LSIZE:.*]] = arith.muli %[[SIZES]]#0, %[[SIZES]]#1 : index
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<16xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<16xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) {
+  %0 = memref.alloc() : memref<4x4xi4>
+  %1 = memref.load %0[%arg0,%arg1] : memref<4x4xi4>
+  return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestEmulateNarrowInt.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
@@ -0,0 +1,117 @@
+//===- TestEmulateNarrowInt.cpp - Test Narrow Int Emulation ------*- c++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass for narrow integer emulation. Integer (and integer vector) types
+/// narrower than `arith-int-bitwidth` are widened to that width; integer
+/// memref elements narrower than `memref-target-bits` are widened to that.
+struct TestEmulateNarrowIntPass
+    : public PassWrapper<TestEmulateNarrowIntPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowIntPass)
+
+  TestEmulateNarrowIntPass() = default;
+  TestEmulateNarrowIntPass(const TestEmulateNarrowIntPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &dialectRegistry) const override {
+    dialectRegistry.insert<arith::ArithDialect, func::FuncDialect,
+                           memref::MemRefDialect, vector::VectorDialect>();
+  }
+  StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Narrow Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    // The memref emulation needs a power-of-two container of at least 8 bits.
+    if (!llvm::isPowerOf2_32(targetWideInt) || targetWideInt < 8) {
+      getOperation()->emitError()
+          << "memref-target-bits must be a power of two >= 8, got "
+          << static_cast<unsigned>(targetWideInt);
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::NarrowIntEmulationConverter typeConverter(targetWideInt);
+
+    // Convert scalar type: integers narrower than `arithBitwidth` are widened
+    // to a signless integer of that width.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      if (ty.getWidth() >= arithBitwidth)
+        return ty;
+      return IntegerType::get(ty.getContext(), arithBitwidth);
+    });
+
+    // Convert vector type: only the integer element type is widened; the
+    // shape (including scalable dimensions) is preserved.
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy || intTy.getWidth() >= arithBitwidth)
+        return ty;
+      return ty.cloneWith(std::nullopt,
+                          IntegerType::get(ty.getContext(), arithBitwidth));
+    });
+
+    memref::populateMemRefNarrowIntEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    // Ops of these dialects are legal once their types need no conversion.
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+
+    arith::populateArithNarrowIntEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowIntEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  /// Bit width that narrow integer memref elements are widened to.
+  Option<unsigned> targetWideInt{
+      *this, "memref-target-bits",
+      llvm::cl::desc("Target memref integer bit width"), llvm::cl::init(8)};
+
+  /// Bit width that narrower integer and integer vector types are widened to.
+  Option<unsigned> arithBitwidth{
+      *this, "arith-int-bitwidth",
+      llvm::cl::desc("Target arith integer bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowIntPass() {
+  PassRegistration<TestEmulateNarrowIntPass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -86,6 +86,7 @@
 void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowIntPass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
@@ -204,6 +205,7 @@
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
+  mlir::test::registerTestEmulateNarrowIntPass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();