diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -28,36 +28,37 @@
 /// contiguous chunk of memory.
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
-/// Returns the flattened 1-D memref and linearized offset for narrow type
-/// emulation.
-///
-/// The emulation only works on 1D memref types. To make this work on N-D
-/// memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-///   %0 = memref.load %arg0[%v0, %v1] :
-///         memref>
-///
-/// can be replaced with
-///
-///   %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-///   %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-///   %linearized_size = %size0 * %size1
-///   %scaled_linear_offset = %linearized_offset / 8 * 4
-///   %scaled_base_offset = %offset / 8 * 4
-///
-///   %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-///                 sizes = [%linearized_size], strides = [%stride#1]
-///
-///   %new_load = memref.load %linearized[%scaled_linear_offset] :
-///               memref>
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
-                            int dstBits, SmallVector<Value> indices,
-                            memref::ExtractStridedMetadataOp stridedMetadata,
-                            OpBuilder &builder);
+/// For a `memref` with `offset`, `sizes` and `strides`, returns the
+/// offset and size to use for the linearized `memref`.
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+/// - If `indices` is provided, it represents the position in the
+///   original `memref` being accessed. The method then returns the
+///   index to use in the linearized `memref`. The linearized index
+///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not
+///   provided, 0 is returned for the linearized index.
+struct LinearizedMemRefInfo {
+  OpFoldResult linearizedOffset;
+  OpFoldResult linearizedSize;
+};
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+    OpBuilder &builder, Location loc, int srcBits, int dstBits,
+    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {});
+
+/// For a `memref` with `offset` and `sizes`, returns the
+/// offset and size to use for the linearized `memref`, assuming that
+/// the strides are computed from a row-major ordering of the sizes:
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+                                 int dstBits, OpFoldResult offset,
+                                 ArrayRef<OpFoldResult> sizes);
 
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -35,18 +35,18 @@
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
 /// element has 4 bits.
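// [Editorial aside, not part of the patch: with sourceBits = 4 and
// targetBits = 8, an element index of 5 lands in byte 5 floordiv 2 = 2 at
// bit offset (5 mod 2) * 4 = 4, i.e. the high nibble of that byte.]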
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
-                                  int targetBits, OpBuilder &builder) {
+static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+                                  int sourceBits, int targetBits,
+                                  OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr idxAttr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
-  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
-  auto srcBitsValue =
-      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
-  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
-  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int scaleFactor = targetBits / sourceBits;
+  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+  IntegerType dstType = builder.getIntegerType(targetBits);
+  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
 namespace {
@@ -61,15 +61,44 @@
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getType());
-    if (!newTy) {
+    MemRefType currentType = op.getMemref().getType().cast<MemRefType>();
+    MemRefType newResultType =
+        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                       op.getType()));
     }
 
+    // Special case zero-rank memrefs.
+    if (currentType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<memref::AllocOp>(
+          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
+          adaptor.getAlignmentAttr());
+      return success();
+    }
+
+    Location loc = op.getLoc();
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
+
+    // Get linearized type.
+    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
+    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
+    OpFoldResult elementOffset = rewriter.getIndexAttr(0);
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    memref::LinearizedMemRefInfo linearizedMemRefInfo =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, srcBits,
+                                                 dstBits, zero, sizes);
+    SmallVector<Value> dynamicLinearizedSize;
+    if (!newResultType.hasStaticShape()) {
+      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, loc, linearizedMemRefInfo.linearizedSize));
+    }
+
     rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
         adaptor.getAlignmentAttr());
     return success();
   }
@@ -109,61 +138,56 @@
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
-    if (!newTy) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
-                                      op.getMemRefType()));
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    // Special case 0-rank memref loads.
+    if (convertedType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, adaptor.getMemref(),
+                                                  ValueRange{});
+      return success();
     }
 
-    if (op.getMemRefType() == newTy)
-      return failure();
-
     auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
-    unsigned sourceRank = sourceType.getRank();
-    SmallVector<Value> indices = adaptor.getIndices();
-    assert(indices.size() == sourceRank);
+    SmallVector<OpFoldResult> indices = getAsOpFoldResult(adaptor.getIndices());
 
-    auto srcElementType = sourceType.getElementType();
+    auto convertedElementType = convertedType.getElementType();
     auto oldElementType = op.getMemRefType().getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
     if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getMemref());
-
-    Value newLoad, lastIdx;
-    if (sourceRank == 0) {
-      newLoad = rewriter.create<memref::LoadOp>(
-          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
-
-      lastIdx = stridedMetadata.getOffset();
-    } else {
-      auto [reinterpret, linearizedOffset] =
-          memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
-                                              adaptor.getIndices(),
-                                              stridedMetadata, rewriter);
-
-      newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
-                                                reinterpret, linearizedOffset);
-
-      lastIdx = adaptor.getIndices().back();
-    }
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getMemRef());
+
+    // Linearize the indices of the original load instruction. Do not account
+    // for the scaling yet. This will be accounted for later.
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, srcBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(), indices);
+
+    AffineExpr s0;
+    bindSymbols(rewriter.getContext(), s0);
+    int64_t scaler = dstBits / srcBits;
+    OpFoldResult scaledLinearizedIndices =
+        affine::makeComposedFoldedAffineApply(
+            rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+    Value newLoad = rewriter.create<memref::LoadOp>(
+        loc, adaptor.getMemref(),
+        getValueOrCreateConstantIndexOp(rewriter, loc,
+                                        scaledLinearizedIndices));
 
     // Get the offset and shift the bits to the rightmost.
     // Note, currently only the big-endian is supported.
-    auto castLastIdx =
-        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
-
-    Value BitwidthOffset =
-        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
     auto bitsLoad =
-        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
 
     // Get the corresponding bits. If the arith computation bitwidth equals
     // to the emulated bitwidth, we apply a mask to extract the low bits.
@@ -172,10 +196,10 @@
     // is different from the emulated bitwidth we truncate the result.
     Operation *result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == srcElementType) {
+    if (resultTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
-          loc, srcElementType,
-          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+          loc, convertedElementType,
+          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
@@ -200,6 +224,25 @@
   patterns
       .add(
           typeConverter, patterns.getContext());
+  memref::populateResolveExtractStridedMetadataPatterns(patterns);
+}
+
+static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
+                                               int dstBits) {
+  if (ty.getRank() == 0)
+    return {};
+
+  int64_t linearizedShape = 1;
+  for (auto shape : ty.getShape()) {
+    if (shape == ShapedType::kDynamic)
+      return {ShapedType::kDynamic};
+    linearizedShape *= shape;
+  }
+  int scale = dstBits / srcBits;
+  // Scale the size to the ceilDiv(linearizedShape, scale)
+  // to accommodate all the values.
+  linearizedShape = (linearizedShape + scale - 1) / scale;
+  return {linearizedShape};
 }
 
 void memref::populateMemRefNarrowTypeEmulationConversions(
@@ -215,11 +258,26 @@
         if (width >= loadStoreWidth)
           return ty;
 
+        // Currently only handle the innermost stride being 1.
+        SmallVector<int64_t> strides;
+        int64_t offset;
+        if (failed(getStridesAndOffset(ty, strides, offset)))
+          return std::nullopt;
+        if (!strides.empty() && strides.back() != 1)
+          return std::nullopt;
+
         auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                           intTy.getSignedness());
         if (!newElemTy)
          return std::nullopt;
 
-        return ty.cloneWith(std::nullopt, newElemTy);
+        StridedLayoutAttr layoutAttr;
+        if (offset != 0) {
+          layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                              ArrayRef<int64_t>{1});
+        }
+
+        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
+                               newElemTy, layoutAttr, ty.getMemorySpace());
       });
 }
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -687,13 +687,17 @@
   auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
   int64_t offset = 0;
-  if (allocLikeOp.getType() == baseBufferType)
-    results.push_back(allocLikeOp);
-  else
-    results.push_back(rewriter.create<memref::ReinterpretCastOp>(
-        loc, baseBufferType, allocLikeOp, offset,
-        /*sizes=*/ArrayRef<int64_t>(),
-        /*strides=*/ArrayRef<int64_t>()));
+  if (op.getBaseBuffer().use_empty()) {
+    results.push_back(nullptr);
+  } else {
+    if (allocLikeOp.getType() == baseBufferType)
+      results.push_back(allocLikeOp);
+    else
+      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+          loc, baseBufferType, allocLikeOp, offset,
+          /*sizes=*/ArrayRef<int64_t>(),
+          /*strides=*/ArrayRef<int64_t>()));
+  }
 
   // Offset.
   results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -46,79 +46,78 @@
   return curDim < 0;
 }
 
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
-                            int dstBits, SmallVector<Value> indices,
-                            memref::ExtractStridedMetadataOp stridedMetadata,
-                            OpBuilder &builder) {
-  auto srcElementType = sourceType.getElementType();
-  unsigned sourceRank = indices.size();
-
-  Value baseBuffer = stridedMetadata.getBaseBuffer();
-  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
-  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
-  Value baseOffset = stridedMetadata.getOffset();
-  assert(indices.size() == baseStrides.size());
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+    OpBuilder &builder, Location loc, int srcBits, int dstBits,
+    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices) {
+  unsigned sourceRank = sizes.size();
+  assert(sizes.size() == strides.size() &&
+         "expected as many sizes as strides for a memref");
+  SmallVector<OpFoldResult> indicesVec = llvm::to_vector(indices);
+  if (indices.empty())
+    indicesVec.resize(sourceRank, builder.getIndexAttr(0));
+  assert(indicesVec.size() == strides.size() &&
+         "expected as many indices as rank of memref");
 
   // Create the affine symbols and values for linearization.
-  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  SmallVector<AffineExpr> symbols(2 * sourceRank);
   bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
-  symbols[0] = builder.getAffineSymbolExpr(0);
-  AffineExpr addMulMap = symbols.front();
-  AffineExpr mulMap = symbols.front();
+  AffineExpr addMulMap = builder.getAffineConstantExpr(0);
+  AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
-  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
-  offsetValues[0] = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
-  sizeValues[0] = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
+  SmallVector<OpFoldResult> sizeValues(sourceRank);
 
   for (unsigned i = 0; i < sourceRank; ++i) {
-    unsigned offsetIdx = 2 * i + 1;
+    unsigned offsetIdx = 2 * i;
     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
-    offsetValues[offsetIdx] = indices[i];
-    offsetValues[offsetIdx + 1] = baseStrides[i];
+    offsetValues[offsetIdx] = indicesVec[i];
+    offsetValues[offsetIdx + 1] = strides[i];
 
-    unsigned sizeIdx = i + 1;
-    mulMap = mulMap * symbols[sizeIdx];
-    sizeValues[sizeIdx] = baseSizes[i];
+    mulMap = mulMap * symbols[i];
   }
 
-  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
-  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
-  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
-  offsetValues.back() = scaler;
+  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
+  // srcBits).
+  int64_t scaler = dstBits / srcBits;
+  addMulMap = addMulMap.floorDiv(scaler);
+  mulMap = mulMap.floorDiv(scaler);
 
-  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, addMulMap, offsetValues);
   OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
   // Adjust baseOffset by the scale factor (dstBits / srcBits).
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
   OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
-
-  // Flatten n-D MemRef to 1-D MemRef.
-  std::optional<int64_t> stride =
-      getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
-  auto layoutAttr =
-      StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
-                             {stride ? stride.value() : ShapedType::kDynamic});
-  int64_t staticShape = sourceType.hasStaticShape()
-                            ? sourceType.getNumElements()
-                            : ShapedType::kDynamic;
-  auto flattenMemrefType = MemRefType::get(
-      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
-
-  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
-      loc, flattenMemrefType, baseBuffer,
-      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
-      baseStrides.back());
-
-  return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
-                                         builder, loc, linearizedOffset));
+      builder, loc, s0.floorDiv(scaler), {offset});
+
+  return {{adjustBaseOffset, linearizedSize}, linearizedIndices};
+}
+
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+                                 int dstBits, OpFoldResult offset,
+                                 ArrayRef<OpFoldResult> sizes) {
+  SmallVector<OpFoldResult> strides(sizes.size());
+  if (sizes.size() > 0) {
+    strides.back() = builder.getIndexAttr(1);
+    AffineExpr s0, s1;
+    bindSymbols(builder.getContext(), s0, s1);
+    for (int index = sizes.size() - 1; index > 0; --index) {
+      strides[index - 1] = affine::makeComposedFoldedAffineApply(
+          builder, loc, s0 * s1,
+          ArrayRef<OpFoldResult>{strides[index], sizes[index]});
+    }
+  }
+
+  LinearizedMemRefInfo linearizedMemRefInfo;
+  std::tie(linearizedMemRefInfo, std::ignore) =
+      getLinearizedMemRefOffsetAndSize(builder, loc, srcBits, dstBits, offset,
+                                       sizes, strides);
+  return linearizedMemRefInfo;
 }
 
 } // namespace memref
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -69,19 +69,24 @@
     if (origElements % scale != 0)
      return failure();
 
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getBase());
-
-    auto [reinterpret, linearizedOffset] = memref::getLinearizeMemRefAndOffset(
-        loc, sourceType, srcBits, dstBits, adaptor.getIndices(),
-        stridedMetadata, rewriter);
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
 
     auto srcElementType = sourceType.getElementType();
     auto numElements = static_cast<int64_t>(
        std::ceil(static_cast<double>(origElements) / scale));
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, srcElementType), reinterpret,
-        linearizedOffset);
+        loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
     numElements *= scale;
     auto castType = VectorType::get(numElements, oldElementType);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir deleted file mode 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir +++ /dev/null @@ -1,107 +0,0 @@ -// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> - -// Expect no conversions, i32 is supported. -// CHECK-LABEL: func @memref_i32 -// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1> -// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1> -// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1> -// CHECK-NEXT: return -func.func @memref_i32() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : i32 - %m = memref.alloc() : memref<4xi32, 1> - %v = memref.load %m[%c0] : memref<4xi32, 1> - memref.store %c1, %m[%c0] : memref<4xi32, 1> - return -} - -// ----- - -// Expect no conversions, f32 is not an integer type. -// CHECK-LABEL: func @memref_f32 -// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1> -// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1> -// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1> -// CHECK-NEXT: return -func.func @memref_f32() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1.0 : f32 - %m = memref.alloc() : memref<4xf32, 1> - %v = memref.load %m[%c0] : memref<4xf32, 1> - memref.store %c1, %m[%c0] : memref<4xf32, 1> - return -} - -// ----- - -// CHECK-LABEL: func @memref_load_i4_zero_rank -// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref -> memref, index -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref -// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 -// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 -// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 -// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 -// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 -// CHECK-NEXT: return -func.func @memref_load_i4_zero_rank() { - %0 = memref.alloc() : memref - %1 = memref.load %0[] : memref - return -} - -// ----- - -// CHECK-LABEL: func @memref_load_i4 -// CHECK-SAME: (%[[ARG:.*]]: index) -// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8> -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref, index, index, index -// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]] -// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] -// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref to memref<4xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 -// CHECK-NEXT: 
%[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 -// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 -// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 -// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 -// CHECK-NEXT: return -func.func @memref_load_i4(%arg0: index) { - %0 = memref.alloc() : memref<4xi4> - %1 = memref.load %0[%arg0] : memref<4xi4> - return -} - -// ----- - -// CHECK-LABEL: func @memref_load_i4_rank2 -// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8> -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref, index, index, index, index, index -// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1] -// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] -// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref to memref<512xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 -// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 -// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 -// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 -// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 -// CHECK-NEXT: return -func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) { - memref.assume_alignment %0, 64 : memref<4x128xi4> - %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4> - return -} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir deleted file mode 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir +++ /dev/null @@ -1,72 +0,0 @@ -// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s - -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> - -// Expect no conversions. 
-// CHECK-LABEL: func @memref_i8 -// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1> -// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1> -// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1> -// CHECK-NEXT: return -func.func @memref_i8() { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : i8 - %m = memref.alloc() : memref<4xi8, 1> - %v = memref.load %m[%c0] : memref<4xi8, 1> - memref.store %c1, %m[%c0] : memref<4xi8, 1> - return -} - -// ----- - -// CHECK-LABEL: func @memref_load_i4 -// CHECK-SAME: (%[[ARG:.*]]: index) -// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8> -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref, index, index, index -// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]] -// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] -// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref to memref<4xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 -// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 -// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 -// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 -// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8 -// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8 -// CHECK-NEXT: return -func.func @memref_load_i4(%arg0: index) { - %0 = memref.alloc() : memref<4xi4> - %1 = memref.load %0[%arg0] : memref<4xi4> - return -} - -// ----- - -// CHECK-LABEL: func @memref_load_i4_rank2 -// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8> -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref, index, index, index, index, index -// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1] -// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] -// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref to memref<512xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8 -// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 -// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 -// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 -// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 -// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 -// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8 -// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8 -// CHECK-NEXT: return -func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) { - memref.assume_alignment %0, 64 : memref<4x128xi4> - %1 = memref.load %0[%arg0,%arg1] : 
memref<4x128xi4> - return -} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -0,0 +1,138 @@ +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32 + +// Expect no conversions. +func.func @memref_i8() -> i8 { + %c3 = arith.constant 3 : index + %m = memref.alloc() : memref<4xi8, 1> + %v = memref.load %m[%c3] : memref<4xi8, 1> + return %v : i8 +} +// CHECK-LABEL: func @memref_i8() +// CHECK: %[[M:.+]] = memref.alloc() : memref<4xi8, 1> +// CHECK-NEXT: %[[V:.+]] = memref.load %[[M]][%{{.+}}] : memref<4xi8, 1> +// CHECK-NEXT: return %[[V]] + +// CHECK32-LABEL: func @memref_i8() +// CHECK32: %[[M:.+]] = memref.alloc() : memref<1xi32, 1> +// CHECK32: %[[C0:.+]] = arith.constant 0 : index +// CHECK32: %[[V:.+]] = memref.load %[[M]][%[[C0]]] : memref<1xi32, 1> +// CHECK32: %[[C24:.+]] = arith.constant 24 : index +// CHECK32: %[[CAST:.+]] = arith.index_cast %[[C24]] : index to i32 +// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[V]], %[[CAST]] +// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i8 +// CHECK32-NEXT: return %[[TRUNC]] + +// ----- + +func.func @memref_load_i4(%arg0: index) -> i4 { + %0 = memref.alloc() : memref<5xi4> + %1 = memref.load %0[%arg0] : memref<5xi4> + return %1 : i4 +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8) +// CHECK: func @memref_load_i4( +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8> +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]] +// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] +// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8 +// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]] +// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4 +// CHECK: return %[[TRUNC]] + +// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)> +// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32) +// CHECK32: func @memref_load_i4( +// CHECK32-SAME: %[[ARG0:.+]]: index +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32> +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]] +// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] +// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32 +// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]] +// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4 +// CHECK32: return %[[TRUNC]] + +// ----- + +func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 { + %0 = memref.alloc() : memref<3x125xi4> + memref.assume_alignment %0, 64 : memref<3x125xi4> + %1 = memref.load %0[%arg0,%arg1] : memref<3x125xi4> + return %1 : i4 +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8) +// CHECK: func @memref_load_i4_rank2( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: 
%[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8> +// CHECK: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8> +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]] +// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8 +// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] +// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4 +// CHECK: return %[[TRUNC]] + +// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)> +// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32) +// CHECK32: func @memref_load_i4_rank2( +// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32> +// CHECK32: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32> +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32 +// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] +// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4 +// CHECK32: return %[[TRUNC]] + +// ----- + +func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index) -> i4 { + %0 = memref.alloc(%arg0, %arg1) : memref + %1 = memref.load %0[%arg2, %arg3] : memref + return %1 : i4 +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)> +// CHECK: func @memref_load_i4_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8 +// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] +// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4 +// CHECK: return %[[TRUNC]] + +// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)> +// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)> +// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)> +// CHECK32: func @memref_load_i4_dynamic( +// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index +// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index +// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index +// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index +// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) +// CHECK32: %[[INDEX:.+]] = affine.apply 
#[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]] +// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32 +// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] +// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4 +// CHECK32: return %[[TRUNC]] diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir @@ -1,36 +1,81 @@ -// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s - -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s +// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32 +func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> { + %0 = memref.alloc() : memref<3x4xi8> + %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8> + return %1 : vector<4xi8> +} // Expect no conversions, i8 is supported. -// CHECK-LABEL: func @vector_load_i8 -// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index) -// CHECK-NEXT: [[L:%.+]] = vector.load %[[ARG]][%[[IDX0]], %[[IDX1]]] : memref<3x4xi8>, vector<4xi8> -// CHECK-NEXT: return -func.func @vector_load_i8(%arg0: memref<3x4xi8>, %arg1: index, %arg2: index) { - %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8> - return +// CHECK: func @vector_load_i8( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8> +// CHECK-NEXT: [[L:%.+]] = vector.load %[[ALLOC]][%[[ARG0]], %[[ARG1]]] : memref<3x4xi8>, vector<4xi8> +// CHECK-NEXT: return + +// CHECK32: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)> +// CHECK32: func @vector_load_i8( +// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[VECLOAD:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32> +// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VECLOAD]] : vector<1xi32> to vector<4xi8> +// CHECK32: return %[[VEC_I4]] + +// ----- + +func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> { + %0 = memref.alloc() : memref<3x8xi4> + %cst = arith.constant dense<0> : vector<3x8xi4> + %1 = vector.load %0[%arg1, %arg2] : memref<3x8xi4>, vector<8xi4> + %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4> + return %2 : vector<3x8xi4> } +// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> +// CHECK: func @vector_load_i4 +// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]] +// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : 
memref<12xi8>, vector<4xi8> +// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4> + +// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)> +// CHECK32: func @vector_load_i4 +// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index) +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32> +// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4> // ----- -// CHECK-LABEL: func @vector_load_i4 -// CHECK-SAME: (%[[ARG:.*]]: memref<3x4xi8>, %[[IDX0:.*]]: index, %[[IDX1:.*]]: index) -// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<0> : vector<3x4xi4> -// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x4xi8> -> memref, index, index, index, index, index -// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[IDX0]], %[[STRIDES]]#0, %[[IDX1]], %[[STRIDES]]#1] -// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1] -// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] -// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref to memref<12xi8, strided<[1], offset: ?>> -// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[CAST]][%[[INDEX]]] : memref<12xi8, strided<[1], offset: ?>>, vector<2xi8> -// CHECK-NEXT: %[[BITCAST:.*]] = vector.bitcast %[[LOAD]] : vector<2xi8> to vector<4xi4> -// CHECK-NEXT: %[[INSERT:.*]] = vector.insert %[[BITCAST]], %[[CST]] [0] : vector<4xi4> into vector<3x4xi4> -// CHECK-NEXT: return -func.func @vector_load_i4(%arg0: memref<3x4xi4>, %arg1: index, %arg2: index) { - %cst = arith.constant dense<0> : vector<3x4xi4> - %0 = vector.load %arg0[%arg1, %arg2] : memref<3x4xi4>, vector<4xi4> - %1 = vector.insert %0, %cst [0] : vector<4xi4> into vector<3x4xi4> - return +func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> { + %0 = memref.alloc(%arg0, %arg1) : memref + %1 = vector.load %0[%arg2, %arg3] : memref, vector<8xi4> + return %1 : vector<8xi4> } +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)> +// CHECK: func.func @vector_load_i4_dynamic( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref, vector<4xi8> +// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4> + +// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)> +// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)> +// CHECK32: func.func @vector_load_i4_dynamic( +// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index +// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK32: 
%[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]] +// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]] +// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref, vector<1xi32> +// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4> diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp --- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -89,8 +89,7 @@ target.addDynamicallyLegalOp(opLegalCallback); target.addDynamicallyLegalDialect< arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect, - affine::AffineDialect>( - [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + affine::AffineDialect>(opLegalCallback); RewritePatternSet patterns(ctx);
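// Editorial illustration (a sketch distilled from the tests above, not an
// additional upstream test): with memref-load-bitwidth=8, an i4 load becomes
// a byte load at the linearized index floordiv 2, followed by arith.shrsi by
// the bit offset and arith.trunci back to i4. The function name below is
// purely illustrative.
func.func @illustrative_i4_load(%idx: index) -> i4 {
  %m = memref.alloc() : memref<8xi4>        // converted to memref<4xi8>
  %v = memref.load %m[%idx] : memref<8xi4>  // becomes load + shrsi + trunci on i8
  return %v : i4
}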