diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
--- a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
@@ -19,8 +19,6 @@ class NarrowIntEmulationConverter : public TypeConverter {
 public:
   explicit NarrowIntEmulationConverter(unsigned targetWideInt);
-
-private:
   unsigned targetBitwidth;
 };
 } // namespace mlir::arith
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
@@ -34,38 +34,10 @@
     : targetBitwidth(targetWideInt) {
   assert(llvm::isPowerOf2_32(targetWideInt) &&
          "Only power-of-two integers are supported");
-  assert(targetWideInt >= 8 && "Target integer type too narrow");
 
   // Allow unknown types.
   addConversion([](Type ty) -> std::optional<Type> { return ty; });
 
-  // Scalar case.
-  addConversion([this](IntegerType ty) -> std::optional<Type> {
-    unsigned width = ty.getWidth();
-    if (width >= targetBitwidth)
-      return ty;
-    else
-      return IntegerType::get(ty.getContext(), targetBitwidth);
-
-    return std::nullopt;
-  });
-
-  // Vector case.
-  addConversion([this](VectorType ty) -> std::optional<Type> {
-    auto intTy = dyn_cast<IntegerType>(ty.getElementType());
-    if (!intTy)
-      return ty;
-
-    unsigned width = intTy.getWidth();
-    if (width >= targetBitwidth)
-      return ty;
-    else
-      return VectorType::get(to_vector(ty.getShape()),
-                             IntegerType::get(ty.getContext(), targetBitwidth));
-
-    return std::nullopt;
-  });
-
   // Function case.
   addConversion([this](FunctionType ty) -> std::optional<Type> {
     SmallVector<Type> inputs;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
@@ -86,16 +86,37 @@ LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newResTy = getTypeConverter()->convertType(op.getType());
-    if (!newResTy)
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy)
       return rewriter.notifyMatchFailure(
           op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                       op.getMemRefType()));
+    if (op.getMemRefType() == cast<MemRefType>(newTy))
+      return success();
 
     auto loc = op.getLoc();
     Value source = adaptor.getMemref();
     auto sourceType = cast<MemRefType>(source.getType());
+    auto srcElementType = sourceType.getElementType();
     unsigned sourceRank = sourceType.getRank();
+
+    auto oldElementType =
+        cast<MemRefType>(op.getMemref().getType()).getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    assert(dstBits % srcBits == 0);
+
+    // The emulation only works on 1D memref types. To make this work on N-D
+    // memref, we need to linearize the offset.
+    // Specifically, %0 = memref.load %0[%v0][%v1] :
+    // memref<?x?xi4> can be replaced with
+    // %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+    // %linearized_offset = %v0 * %stride#0 + %scaled_v1 * %stride#1
+    // where %scaled_v1 = v1 / targetBits * sourceBits
+    // %linearized_size = %size0 * %size1
+    // %linearized = memref.reinterpret_cast %b, offset = [%offset], sizes =
+    // [%linearized_size], strides = [%stride#1] %load = memref.load
+    // %linearized[%linearized_offset] : memref<?xi8>
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
     auto baseBuffer = stridedMetadata.getBaseBuffer();
@@ -103,12 +124,6 @@
     auto baseStrides = stridedMetadata.getStrides();
     auto baseOffset = stridedMetadata.getOffset();
-    int srcBits = cast<MemRefType>(op.getMemref().getType())
-                      .getElementType()
-                      .getIntOrFloatBitWidth();
-    int dstBits = sourceType.getElementType().getIntOrFloatBitWidth();
-    assert(dstBits % srcBits == 0);
-
     SmallVector<Value> indices = adaptor.getIndices();
     assert(indices.size() == baseStrides.size());
     assert(indices.size() == sourceRank);
@@ -127,49 +142,45 @@
     }
 
     // Linearize offset and sizes.
-    Value linearized_offset = adjustOffsets[0];
-    Value linearized_size = baseSizes[0];
+    Value linearizedOffset = adjustOffsets[0];
+    Value linearizedSize = baseSizes[0];
     for (unsigned i = 1; i < sourceRank; ++i) {
-      linearized_offset = rewriter.create<arith::AddIOp>(loc, linearized_offset,
-                                                         adjustOffsets[i]);
-      linearized_size =
-          rewriter.create<arith::MulIOp>(loc, linearized_size, baseSizes[i]);
+      linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset,
+                                                        adjustOffsets[i]);
+      linearizedSize =
+          rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
     }
 
     // Flatten n-D MemRef to 1-D MemRef.
-    StridedLayoutAttr layoutAttr = StridedLayoutAttr::get(
+    auto layoutAttr = StridedLayoutAttr::get(
         sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
     int64_t staticShape = sourceType.hasStaticShape() ?
                               sourceType.getNumElements() : ShapedType::kDynamic;
-    auto flattenMemrefType =
-        MemRefType::get(staticShape, sourceType.getElementType(), layoutAttr,
-                        sourceType.getMemorySpace());
+    auto flattenMemrefType = MemRefType::get(
+        staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
 
     auto reinterpret = rewriter.create<memref::ReinterpretCastOp>(
-        loc, flattenMemrefType, baseBuffer, baseOffset, linearized_size,
+        loc, flattenMemrefType, baseBuffer, baseOffset, linearizedSize,
         baseStrides.back());
 
-    auto newLoad =
-        rewriter.create<memref::LoadOp>(loc, newResTy, reinterpret.getResult(),
-                                        linearized_offset, op.getNontemporal());
+    auto newLoad = rewriter.create<memref::LoadOp>(
+        loc, srcElementType, reinterpret.getResult(), linearizedOffset,
+        op.getNontemporal());
 
     // Get the offset and shift the bits to the rightmost.
     auto lastIdx = rewriter.create<arith::IndexCastUIOp>(
-        loc, sourceType.getElementType(), adaptor.getIndices().back());
+        loc, srcElementType, adaptor.getIndices().back());
     Value BitwidthOffset =
         getOffsetForBitwidth(loc, lastIdx, srcBits, dstBits, rewriter);
     auto bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
 
-    // Apply the mask to extract corresponding bits.
-    auto mask = rewriter.create<arith::ConstantOp>(
-        loc, sourceType.getElementType(),
-        rewriter.getIntegerAttr(sourceType.getElementType(),
-                                (1 << srcBits) - 1));
-    auto result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
-
+    // Get the low bits by truncating the result.
+    auto result =
+        rewriter.create<arith::TruncIOp>(loc, oldElementType, bitsLoad);
     rewriter.replaceOp(op, result.getResult());
+
     return success();
   }
 };
@@ -196,10 +207,16 @@
     if (!intTy)
       return ty;
 
-    Type newElemTy = typeConverter.convertType(intTy);
-    if (!newElemTy)
-      return std::nullopt;
-
-    return ty.cloneWith(std::nullopt, newElemTy);
+    unsigned width = intTy.getWidth();
+    if (width >= typeConverter.targetBitwidth)
+      return ty;
+    else {
+      Type newElemTy =
+          IntegerType::get(ty.getContext(), typeConverter.targetBitwidth,
+                           intTy.getSignedness());
+      if (!newElemTy)
+        return std::nullopt;
+      return ty.cloneWith(std::nullopt, newElemTy);
+    }
   });
 }
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-int.mlir b/mlir/test/Dialect/Arith/emulate-narrow-int.mlir
--- a/mlir/test/Dialect/Arith/emulate-narrow-int.mlir
+++ b/mlir/test/Dialect/Arith/emulate-narrow-int.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-emulate-narrow-int="target-wide-int=8" %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="int4-arith-bitwidth=8" %s | FileCheck %s
 
 // Expect no conversions, f32 is not an integer type.
 // CHECK-LABEL: func @identity_f32
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
--- a/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --test-emulate-narrow-int="target-wide-int=8" %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="int4-arith-bitwidth=4 memref-target-bits=8" %s | FileCheck %s
 
 // Expect no conversions, i32 is supported.
 // CHECK-LABEL: func @memref_i32
@@ -45,8 +45,7 @@
 // CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
 // CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
 // CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
 // CHECK-NEXT: return
 func.func @memref_load_i4(%arg0: index) {
   %0 = memref.alloc() : memref<4xi4>
@@ -73,11 +72,10 @@
 // CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
 // CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
 // CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
-// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
-// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
 // CHECK-NEXT: return
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) {
   %0 = memref.alloc() : memref<4x4xi4>
   %1 = memref.load %0[%arg0,%arg1] : memref<4x4xi4>
   return
-}
\ No newline at end of file
+}
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
--- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
+++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
@@ -20,6 +20,7 @@
 using namespace mlir;
 
 namespace {
+
 struct TestEmulateNarrowIntPass
     : public PassWrapper<TestEmulateNarrowIntPass,
                          OperationPass<func::FuncOp>> {
@@ -48,6 +49,35 @@
     MLIRContext *ctx = op->getContext();
 
     arith::NarrowIntEmulationConverter typeConverter(targetWideInt);
+
+    // Convert scalar type.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      unsigned width = ty.getWidth();
+      if (width >= arithBitwidth)
+        return ty;
+      else
+        return IntegerType::get(ty.getContext(), arithBitwidth);
+
+      return std::nullopt;
+    });
+
+    // Convert vector type.
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy)
+        return ty;
+
+      unsigned width = intTy.getWidth();
+      if (width >= arithBitwidth)
+        return ty;
+      else
+        return VectorType::get(
+            to_vector(ty.getShape()),
+            IntegerType::get(ty.getContext(), arithBitwidth));
+
+      return std::nullopt;
+    });
+
     memref::populateMemRefNarrowIntEmulationConversions(typeConverter);
 
     ConversionTarget target(*ctx);
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
@@ -63,7 +93,6 @@
 
     RewritePatternSet patterns(ctx);
 
-    // Add common pattenrs to support contants, functions, etc.
     arith::populateArithNarrowIntEmulationPatterns(typeConverter, patterns);
     memref::populateMemRefNarrowIntEmulationPatterns(typeConverter, patterns);
 
@@ -71,9 +100,13 @@
       signalPassFailure();
   }
 
-  Option<unsigned> targetWideInt{*this, "target-wide-int",
+  Option<unsigned> targetWideInt{*this, "memref-target-bits",
                                  llvm::cl::desc("Target integer bit width"),
                                  llvm::cl::init(8)};
+
+  Option<unsigned> arithBitwidth{*this, "int4-arith-bitwidth",
+                                 llvm::cl::desc("Target integer bit width"),
+                                 llvm::cl::init(4)};
 };
 } // namespace