diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integers that are equal to or wider than 8 bits.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+  explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+  unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+  unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -22,6 +22,7 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 
 /// Create a pass to bufferize Arith ops.
 std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,12 @@
 void populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Adds patterns to emulate narrow Arith and Function ops with wider
+/// supported types. Users need to add conversions for the computation
+/// domain of the narrow types themselves.
+void populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
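For orientation, the sketch below shows how the entry points introduced by this patch are meant to be combined by a client. This is illustrative code, not part of the patch: the helper name `runNarrowTypeEmulation` and the hard-coded 8-bit width are assumptions, and the wiring mirrors the `TestEmulateNarrowTypePass` added at the end of this diff.

```cpp
// Assumed client code (not from the patch); mirrors the test pass below.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult runNarrowTypeEmulation(Operation *root) {
  MLIRContext *ctx = root->getContext();

  // Emulate sub-byte types with i8-granularity loads/stores.
  arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(converter);
  // Note: the client must also add scalar/vector conversions for the
  // compute domain (see the note in Passes.h and the test pass below).

  RewritePatternSet patterns(ctx);
  arith::populateArithNarrowTypeEmulationPatterns(converter, patterns);
  memref::populateMemRefNarrowTypeEmulationPatterns(converter, patterns);

  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&converter](Operation *op) { return converter.isLegal(op); });
  return applyPartialConversion(root, target, std::move(patterns));
}
```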
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -25,6 +25,7 @@
 namespace arith {
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 } // namespace arith
 
 namespace memref {
@@ -73,6 +74,17 @@
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);
 
+/// Appends patterns for emulating memref operations over narrow types with ops
+/// over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
+void populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExpandOps.cpp
   IntNarrowing.cpp
   IntRangeOptimizations.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,61 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+    unsigned targetBitwidth)
+    : loadStoreBitwidth(targetBitwidth) {
+  assert(llvm::isPowerOf2_32(targetBitwidth) &&
+         "Only power-of-two integers are supported");
+
+  // Allow unknown types.
+  addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> std::optional<Type> {
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return std::nullopt;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return std::nullopt;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
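The constructor deliberately leaves the narrow scalar rule to the client, as the note in Passes.h says. A minimal sketch of what that looks like, and of what the FunctionType conversion then produces, follows; the function name and the 8-bit choice are illustrative, not part of the patch.

```cpp
// Assumed client code, not from the patch.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/IR/Builders.h"
#include <optional>

using namespace mlir;

void addNarrowScalarRule(MLIRContext &ctx) {
  arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
  converter.addConversion([](IntegerType ty) -> std::optional<Type> {
    if (ty.getWidth() >= 8)
      return ty;                                 // i8 and wider stay legal.
    return IntegerType::get(ty.getContext(), 8); // i1..i7 widen to i8.
  });

  Builder b(&ctx);
  auto fnTy = FunctionType::get(&ctx, {b.getIntegerType(4)},
                                {b.getIntegerType(4)});
  // The FunctionType rule registered in the constructor recurses through
  // convertTypes(), so (i4) -> i4 becomes (i8) -> i8.
  Type converted = converter.convertType(fnTy);
  (void)converted;
}
```

Conversions added later take precedence in `TypeConverter`, so the client's `IntegerType` rule fires before the pass-through `Type` rule registered in the constructor.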
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+      patterns, typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -0,0 +1,315 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1D memref types.
+/// To make this work on N-D memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / (8 / 4)
+/// %scaled_base_offset = %offset / (8 / 4)
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+///                      sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+///                         memref<?xi8, strided<[?], offset: ?>>
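As an editorial aside, the arithmetic in that comment can be restated in plain C++; this is an illustrative model only (the function and parameter names are not from the patch), shown for the srcBits = 4, dstBits = 8 case where the scale factor is 2, i.e. two i4 elements per i8.

```cpp
// Plain-C++ model of the linearization and scaling performed by
// linearizeMemrefLoad below (illustrative only).
#include <cassert>
#include <cstdint>
#include <vector>

int64_t scaledLinearIndex(const std::vector<int64_t> &indices,
                          const std::vector<int64_t> &strides,
                          int64_t srcBits, int64_t dstBits) {
  assert(indices.size() == strides.size() && dstBits % srcBits == 0);
  int64_t linearized = 0;
  for (size_t i = 0; i < indices.size(); ++i)
    linearized += indices[i] * strides[i]; // %linearized_offset
  int64_t scale = dstBits / srcBits;       // narrow elements per wide word
  return linearized / scale;               // %scaled_linear_offset
}
```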
+
+static Value
+linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
+                    int dstBits, SmallVector<Value> indices,
+                    memref::ExtractStridedMetadataOp stridedMetadata,
+                    OpBuilder &builder) {
+  auto srcElementType = sourceType.getElementType();
+  unsigned sourceRank = indices.size();
+
+  Value baseBuffer = stridedMetadata.getBaseBuffer();
+  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+  Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size());
+
+  // Create the affine symbols and values for linearization.
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  symbols[0] = builder.getAffineSymbolExpr(0);
+  AffineExpr addMulMap = symbols.front();
+  AffineExpr mulMap = symbols.front();
+
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+  offsetValues[0] = builder.getIndexAttr(0);
+  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+  sizeValues[0] = builder.getIndexAttr(1);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    unsigned offsetIdx = 2 * i + 1;
+    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+    offsetValues[offsetIdx] = indices[i];
+    offsetValues[offsetIdx + 1] = baseStrides[i];
+
+    unsigned sizeIdx = i + 1;
+    mulMap = mulMap * symbols[sizeIdx];
+    sizeValues[sizeIdx] = baseSizes[i];
+  }
+
+  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+  offsetValues.back() = scaler;
+
+  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+  // Adjust baseOffset by the scale factor (dstBits / srcBits).
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+  // Flatten n-D MemRef to 1-D MemRef.
+  auto layoutAttr = StridedLayoutAttr::get(
+      sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+  int64_t staticShape = sourceType.hasStaticShape()
+                            ? sourceType.getNumElements()
+                            : ShapedType::kDynamic;
+  auto flattenMemrefType = MemRefType::get(
+      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+      loc, flattenMemrefType, baseBuffer,
+      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+      baseStrides.back());
+
+  return builder.create<memref::LoadOp>(
+      loc, srcElementType, reinterpret.getResult(),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
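A companion sketch for the bit-level step (again illustrative, not patch code): once the wide word is loaded, the narrow element is recovered with a shift and mask. The patch materializes this as `arith.remui`/`arith.muli` (in `getOffsetForBitwidth` below) followed by `arith.shrsi` and `arith.andi`/`arith.trunci` (in `ConvertMemRefLoad`); the C++ below uses a logical shift for simplicity.

```cpp
// Plain-C++ model of the bit extraction (illustrative only).
#include <cstdint>

uint8_t extractNarrowElement(uint8_t word, uint64_t lastIdx,
                             unsigned srcBits, unsigned dstBits) {
  unsigned perWord = dstBits / srcBits;               // 2 for i4-in-i8.
  unsigned bitOffset = (lastIdx % perWord) * srcBits; // e.g. 0 or 4.
  return (word >> bitOffset) & ((1u << srcBits) - 1); // mask the low bits.
}
```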
+
+/// When data is loaded/stored in `targetBits` granularity, but is used in
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
+/// treated as an array of elements of width `sourceBits`.
+/// Return the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
+/// located at (x % 2) * 4, because there are two elements in one i8 and one
+/// element has 4 bits.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, cast<MemRefType>(newTy), adaptor.getDynamicSizes(),
+        adaptor.getSymbolOperands(), adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAssumeAlignment final
+    : OpConversionPattern<memref::AssumeAlignmentOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemref().getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
+        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemRefType()));
+    }
+
+    if (op.getMemRefType() == newTy)
+      return failure();
+
+    auto loc = op.getLoc();
+    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Value> indices = adaptor.getIndices();
+    assert(indices.size() == sourceRank);
+
+    auto srcElementType = sourceType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, adaptor.getMemref());
+
+    Value newLoad, lastIdx;
+    if (sourceRank == 0) {
+      newLoad = rewriter.create<memref::LoadOp>(
+          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
+
+      lastIdx = stridedMetadata.getOffset();
+    } else {
+      newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
+                                    stridedMetadata, rewriter);
+
+      lastIdx = adaptor.getIndices().back();
+    }
+
+    // Get the offset and shift the bits to the rightmost position.
+    // Note that currently only big-endian is supported.
+    auto castLastIdx =
+        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
+
+    Value bitwidthOffset =
+        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+    auto bitsLoad =
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
+
+    // Get the corresponding bits. If the arith computation bitwidth equals
+    // the emulated bitwidth, we apply a mask to extract the low bits. It is
+    // not clear if this case actually happens in practice, but we keep the
+    // operations just in case. Otherwise, if the arith computation bitwidth
+    // is different from the emulated bitwidth, we truncate the result.
+    Operation *result;
+    auto resultTy = getTypeConverter()->convertType(oldElementType);
+    if (resultTy == srcElementType) {
+      auto mask = rewriter.create<arith::ConstantOp>(
+          loc, srcElementType,
+          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+    } else {
+      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+    }
+
+    rewriter.replaceOp(op, result->getResult(0));
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+
+  // Populate `memref.*` conversion patterns.
+  patterns.add<ConvertMemRefAlloc, ConvertMemRefAssumeAlignment,
+               ConvertMemRefLoad>(typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy)
+          return ty;
+
+        unsigned width = intTy.getWidth();
+        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+        if (width >= loadStoreWidth)
+          return ty;
+
+        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+                                          intTy.getSignedness());
+        if (!newElemTy)
+          return std::nullopt;
+
+        return ty.cloneWith(std::nullopt, newElemTy);
+      });
+}
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT: return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+  return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+  return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT: return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+  return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+  return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+  return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+  %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+  return %res : vector<4xi4>
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
@@ -0,0 +1,107 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : i32
+  %m = memref.alloc() : memref<4xi32, 1>
+  %v = memref.load %m[%c0] : memref<4xi32, 1>
+  memref.store %c1, %m[%c0] : memref<4xi32, 1>
+  return
+}
+
+// -----
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1.0 : f32
+  %m = memref.alloc() : memref<4xf32, 1>
+  %v = memref.load %m[%c0] : memref<4xf32, 1>
+  memref.store %c1, %m[%c0] : memref<4xf32, 1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_zero_rank
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<i8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_zero_rank() {
+  %0 = memref.alloc() : memref<i4>
+  %1 = memref.load %0[] : memref<i4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+  %0 = memref.alloc() : memref<4xi4>
+  %1 = memref.load %0[%arg0] : memref<4xi4>
+  return
+}
+
+// -----
+
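As an editorial aside (not part of the test file), the numbers in the rank-2 case below can be checked by hand. The sketch assumes the illustrative index values `%arg0 = 1`, `%arg1 = 3` for a `memref<4x128xi4>`, whose strides are [128, 1]:

```cpp
// Worked instance of @memref_load_i4_rank2 (illustrative values only).
static_assert(1 * 128 + 3 * 1 == 131, "linearized index, in i4 elements");
static_assert(131 / 2 == 65, "byte actually loaded (scale = 8 / 4 = 2)");
static_assert((3 % 2) * 4 == 4, "bit offset; uses only the last index");
```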
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+  memref.assume_alignment %0, 64 : memref<4x128xi4>
+  %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions.
+// CHECK-LABEL: func @memref_i8
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: return
+func.func @memref_i8() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : i8
+  %m = memref.alloc() : memref<4xi8, 1>
+  %v = memref.load %m[%c0] : memref<4xi8, 1>
+  memref.store %c1, %m[%c0] : memref<4xi8, 1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+  %0 = memref.alloc() : memref<4xi4>
+  %1 = memref.load %0[%arg0] : memref<4xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+  memref.assume_alignment %0, 64 : memref<4x128xi4>
+  %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+  return
+}
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestEmulateNarrowType.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
@@ -0,0 +1,118 @@
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+    : public PassWrapper<TestEmulateNarrowTypePass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+  TestEmulateNarrowTypePass() = default;
+  TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, func::FuncDialect,
+                    memref::MemRefDialect, vector::VectorDialect,
+                    affine::AffineDialect>();
+  }
+  StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Narrow Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+        loadStoreEmulateBitwidth < 8) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::NarrowTypeEmulationConverter typeConverter(
+        loadStoreEmulateBitwidth);
+
+    // Convert scalar type.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      unsigned width = ty.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+    });
+
+    // Convert vector type.
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy)
+        return ty;
+
+      unsigned width = intTy.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return VectorType::get(
+          to_vector(ty.getShape()),
+          IntegerType::get(ty.getContext(), arithComputeBitwidth));
+    });
+
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>(
+        [&typeConverter](Operation *op) {
+          return typeConverter.isLegal(
+              cast<func::FuncOp>(op).getFunctionType());
+        });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+
+    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  Option<unsigned> loadStoreEmulateBitwidth{
+      *this, "memref-load-bitwidth",
+      llvm::cl::desc("memref load/store emulation bit width"),
+      llvm::cl::init(8)};
+
+  Option<unsigned> arithComputeBitwidth{
+      *this, "arith-compute-bitwidth",
+      llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+  PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -87,6 +87,7 @@
 void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
@@ -205,6 +206,7 @@
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
+  mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();