diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -16,6 +16,8 @@
 #ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 #define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
 namespace mlir {
 
 class MemRefType;
@@ -26,6 +28,37 @@
 /// contiguous chunk of memory.
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
+/// Returns the flattened 1-D memref and linearized offset for narrow type
+/// emulation.
+///
+/// The emulation only works on 1D memref types. To make this work on N-D
+/// memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / 8 * 4
+/// %scaled_base_offset = %offset / 8 * 4
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+///                      sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+///                  memref<?xi8, strided<[?], offset: ?>>
+std::pair<Value, Value>
+getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
+                            int dstBits, SmallVector<Value> indices,
+                            memref::ExtractStridedMetadataOp stridedMetadata,
+                            OpBuilder &builder);
+
 } // namespace memref
 } // namespace mlir
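Note on the new declaration: the returned pair is (flattened 1-D memref, scaled linearized index), and callers are expected to build their own load on top of it. A minimal call-site sketch mirroring the patterns updated later in this patch — the helper name, the includes, and the conversion-pattern context are assumptions for illustration, not part of the patch:

    // Sketch only: assumes a conversion-pattern context in which `sourceType`
    // is the already-converted byte-element memref type and srcBits/dstBits
    // are the narrow and storage bitwidths (e.g. 4 and 8).
    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    static Value emitEmulatedLoad(ConversionPatternRewriter &rewriter,
                                  Location loc, MemRefType sourceType,
                                  int srcBits, int dstBits,
                                  memref::LoadOpAdaptor adaptor) {
      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
          loc, adaptor.getMemref());
      // First element: the flattened 1-D memref; second: the scaled,
      // linearized index into it.
      auto [linearized, linearizedOffset] = memref::getLinearizeMemRefAndOffset(
          loc, sourceType, srcBits, dstBits, adaptor.getIndices(),
          stridedMetadata, rewriter);
      return rewriter.create<memref::LoadOp>(loc, sourceType.getElementType(),
                                             linearized, linearizedOffset);
    }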
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -11,7 +11,6 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -27,102 +28,6 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// The emulation only works on 1D memref types.
-/// To make this work on N-D memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-/// %0 = memref.load %arg0[%v0, %v1] :
-///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
-///
-/// can be replaced with
-///
-/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-/// %linearized_size = %size0 * %size1
-/// %scaled_linear_offset = %linearized_offset / 8 * 4
-/// %scaled_base_offset = %offset / 8 * 4
-///
-/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-///                      sizes = [%linearized_size], strides = [%stride#1]
-///
-/// %new_load = memref.load %linearized[%scaled_linear_offset] :
-///                  memref<?xi8, strided<[?], offset: ?>>
-
-static Value
-linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
-                    int dstBits, SmallVector<Value> indices,
-                    memref::ExtractStridedMetadataOp stridedMetadata,
-                    OpBuilder &builder) {
-  auto srcElementType = sourceType.getElementType();
-  unsigned sourceRank = indices.size();
-
-  Value baseBuffer = stridedMetadata.getBaseBuffer();
-  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
-  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
-  Value baseOffset = stridedMetadata.getOffset();
-  assert(indices.size() == baseStrides.size());
-
-  // Create the affine symbols and values for linearization.
-  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
-  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
-  symbols[0] = builder.getAffineSymbolExpr(0);
-  AffineExpr addMulMap = symbols.front();
-  AffineExpr mulMap = symbols.front();
-
-  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
-  offsetValues[0] = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
-  sizeValues[0] = builder.getIndexAttr(1);
-
-  for (unsigned i = 0; i < sourceRank; ++i) {
-    unsigned offsetIdx = 2 * i + 1;
-    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
-    offsetValues[offsetIdx] = indices[i];
-    offsetValues[offsetIdx + 1] = baseStrides[i];
-
-    unsigned sizeIdx = i + 1;
-    mulMap = mulMap * symbols[sizeIdx];
-    sizeValues[sizeIdx] = baseSizes[i];
-  }
-
-  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
-  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
-  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
-  offsetValues.back() = scaler;
-
-  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, scaledAddMulMap, offsetValues);
-  OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
-
-  // Adjust baseOffset by the scale factor (dstBits / srcBits).
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
-
-  // Flatten n-D MemRef to 1-D MemRef.
-  auto layoutAttr = StridedLayoutAttr::get(
-      sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
-  int64_t staticShape = sourceType.hasStaticShape()
-                            ? sourceType.getNumElements()
-                            : ShapedType::kDynamic;
-  auto flattenMemrefType = MemRefType::get(
-      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
-
-  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
-      loc, flattenMemrefType, baseBuffer,
-      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
-      baseStrides.back());
-
-  return builder.create<memref::LoadOp>(
-      loc, srcElementType, reinterpret.getResult(),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
-}
-
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -239,8 +144,13 @@
       lastIdx = stridedMetadata.getOffset();
     } else {
-      newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
-                                    stridedMetadata, rewriter);
+      auto [reinterpret, linearizedOffset] =
+          memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
+                                              adaptor.getIndices(),
+                                              stridedMetadata, rewriter);
+
+      newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
+                                                reinterpret, linearizedOffset);
       lastIdx = adaptor.getIndices().back();
     }
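As a sanity check on the arithmetic this hunk now delegates to memref::getLinearizeMemRefAndOffset, here is the doc-comment example worked through with concrete made-up numbers (plain C++, not MLIR; the 3x8 row-major i4 memref and the indices are assumptions for illustration):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Hypothetical load memref<3x8xi4, strided<[8, 1]>>[%v0 = 1, %v1 = 4],
      // emulated on i8 storage: srcBits = 4, dstBits = 8.
      int64_t v0 = 1, v1 = 4;
      int64_t stride0 = 8, stride1 = 1;
      int64_t size0 = 3, size1 = 8;
      int64_t srcBits = 4, dstBits = 8;

      int64_t scale = dstBits / srcBits;                      // 2 i4 lanes per byte
      int64_t linearizedOffset = v0 * stride0 + v1 * stride1; // 12, in i4 units
      int64_t linearizedSize = size0 * size1;                 // 24 i4 elements
      int64_t scaledOffset = linearizedOffset / scale;        // i8 element 6

      std::cout << linearizedOffset << ' ' << linearizedSize << ' '
                << scaledOffset << '\n'; // prints: 12 24 6
    }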
diff --git a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt
@@ -6,4 +6,6 @@
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRAffineDialect
+  MLIRArithUtils
 )
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 
 namespace mlir {
@@ -44,5 +46,80 @@
   return curDim < 0;
 }
 
+std::pair<Value, Value>
+getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
+                            int dstBits, SmallVector<Value> indices,
+                            memref::ExtractStridedMetadataOp stridedMetadata,
+                            OpBuilder &builder) {
+  auto srcElementType = sourceType.getElementType();
+  unsigned sourceRank = indices.size();
+
+  Value baseBuffer = stridedMetadata.getBaseBuffer();
+  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+  Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size());
+
+  // Create the affine symbols and values for linearization.
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  symbols[0] = builder.getAffineSymbolExpr(0);
+  AffineExpr addMulMap = symbols.front();
+  AffineExpr mulMap = symbols.front();
+
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+  offsetValues[0] = builder.getIndexAttr(0);
+  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+  sizeValues[0] = builder.getIndexAttr(1);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    unsigned offsetIdx = 2 * i + 1;
+    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+    offsetValues[offsetIdx] = indices[i];
+    offsetValues[offsetIdx + 1] = baseStrides[i];
+
+    unsigned sizeIdx = i + 1;
+    mulMap = mulMap * symbols[sizeIdx];
+    sizeValues[sizeIdx] = baseSizes[i];
+  }
+
+  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+  offsetValues.back() = scaler;
+
+  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+  // Adjust baseOffset by the scale factor (dstBits / srcBits).
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+  // Flatten n-D MemRef to 1-D MemRef.
+  std::optional<int64_t> stride =
+      getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
+  auto layoutAttr =
+      StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
+                             {stride ? stride.value() : ShapedType::kDynamic});
+  int64_t staticShape = sourceType.hasStaticShape()
+                            ? sourceType.getNumElements()
+                            : ShapedType::kDynamic;
+  auto flattenMemrefType = MemRefType::get(
+      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+      loc, flattenMemrefType, baseBuffer,
+      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+      baseStrides.back());
+
+  return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
+                                         builder, loc, linearizedOffset));
+}
+
 } // namespace memref
 } // namespace mlir
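Worth calling out: this is not a purely mechanical move. The deleted memref-path helper always used a dynamic innermost stride, while the shared version asks getConstifiedMixedStrides() for a constant, so identity row-major sources now get a static stride<[1]> in the flattened type (the test updates below pin this down). A stand-alone model of just that selection step, with kDynamic standing in for ShapedType::kDynamic and illustrative values:

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <optional>

    // Stand-in for ShapedType::kDynamic (in MLIR it is int64_t's minimum).
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

    // If the innermost stride folded to a constant, bake it into the
    // flattened 1-D layout; otherwise keep the stride dynamic.
    int64_t pickInnerStride(std::optional<int64_t> foldedStride) {
      return foldedStride ? *foldedStride : kDynamic;
    }

    int main() {
      // Identity row-major layout: innermost stride constifies to 1, giving
      // strided<[1], offset: ?> instead of strided<[?], offset: ?>.
      std::cout << pickInnerStride(1) << '\n';                          // 1
      // Fully dynamic strides: the layout stays strided<[?], offset: ?>.
      std::cout << (pickInnerStride(std::nullopt) == kDynamic) << '\n'; // 1
    }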
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -37,6 +37,7 @@
   MLIRIR
   MLIRLinalgDialect
   MLIRMemRefDialect
+  MLIRMemRefUtils
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -21,107 +22,6 @@
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// Utility functions
-//===----------------------------------------------------------------------===//
-
-/// The emulation only works on 1D memref types.
-/// To make this work on N-D memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-/// %0 = memref.load %arg0[%v0, %v1] :
-///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
-///
-/// can be replaced with
-///
-/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-/// %linearized_size = %size0 * %size1
-/// %scaled_linear_offset = %linearized_offset / 8 * 4
-/// %scaled_base_offset = %offset / 8 * 4
-///
-/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-///                      sizes = [%linearized_size], strides = [%stride#1]
-///
-/// %new_load = vector.load %linearized[%scaled_linear_offset] :
-///                  memref<?xi8, strided<[?], offset: ?>>
-
-static Value
-linearizeVectorLoad(Location loc, MemRefType sourceType, int srcBits,
-                    int dstBits, SmallVector<Value> indices,
-                    memref::ExtractStridedMetadataOp stridedMetadata,
-                    int numElements, OpBuilder &builder) {
-  auto srcElementType = sourceType.getElementType();
-  unsigned sourceRank = indices.size();
-
-  Value baseBuffer = stridedMetadata.getBaseBuffer();
-  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
-  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
-  Value baseOffset = stridedMetadata.getOffset();
-  assert(indices.size() == baseStrides.size());
-
-  // Create the affine symbols and values for linearization.
-  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
-  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
-  symbols[0] = builder.getAffineSymbolExpr(0);
-  AffineExpr addMulMap = symbols.front();
-  AffineExpr mulMap = symbols.front();
-
-  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
-  offsetValues[0] = builder.getIndexAttr(0);
-  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
-  sizeValues[0] = builder.getIndexAttr(1);
-
-  for (unsigned i = 0; i < sourceRank; ++i) {
-    unsigned offsetIdx = 2 * i + 1;
-    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
-    offsetValues[offsetIdx] = indices[i];
-    offsetValues[offsetIdx + 1] = baseStrides[i];
-
-    unsigned sizeIdx = i + 1;
-    mulMap = mulMap * symbols[sizeIdx];
-    sizeValues[sizeIdx] = baseSizes[i];
-  }
-
-  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
-  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
-  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
-  offsetValues.back() = scaler;
-
-  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, scaledAddMulMap, offsetValues);
-  OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
-
-  // Adjust baseOffset by the scale factor (dstBits / srcBits).
-  AffineExpr s0, s1;
-  bindSymbols(builder.getContext(), s0, s1);
-  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
-      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
-
-  // Flatten n-D MemRef to 1-D MemRef.
-  auto layoutAttr = StridedLayoutAttr::get(sourceType.getContext(),
-                                           ShapedType::kDynamic, {1});
-  int64_t staticShape = sourceType.hasStaticShape()
-                            ? sourceType.getNumElements()
-                            : ShapedType::kDynamic;
-  auto flattenMemrefType = MemRefType::get(
-      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
-
-  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
-      loc, flattenMemrefType, baseBuffer,
-      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
-      baseStrides.back());
-
-  return builder.create<vector::LoadOp>(
-      loc, VectorType::get(numElements, srcElementType),
-      reinterpret.getResult(),
-      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -172,10 +72,16 @@
     auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
         loc, adaptor.getBase());
 
-    auto numElements = int(std::ceil(double(origElements) / scale));
-    auto newLoad = linearizeVectorLoad(loc, sourceType, srcBits, dstBits,
-                                       adaptor.getIndices(), stridedMetadata,
-                                       numElements, rewriter);
+    auto [reinterpret, linearizedOffset] = memref::getLinearizeMemRefAndOffset(
+        loc, sourceType, srcBits, dstBits, adaptor.getIndices(),
+        stridedMetadata, rewriter);
+
+    auto srcElementType = sourceType.getElementType();
+    auto numElements =
+        static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loc, VectorType::get(numElements, srcElementType), reinterpret,
+        linearizedOffset);
 
     numElements *= scale;
     auto castType = VectorType::get(numElements, oldElementType);
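The vector path keeps its ceil-division sizing for the byte-granularity load, now spelled with static_cast instead of the old functional-style casts. Worked through with assumed numbers — 3 i4 lanes loaded out of i8 storage — in stand-alone C++:

    #include <cmath>
    #include <iostream>

    int main() {
      // Hypothetical vector<3xi4> load emulated on i8 storage: scale = 2
      // (dstBits / srcBits = 8 / 4), so 2 bytes cover the 3 requested lanes;
      // they bitcast back to vector<4xi4> before the narrow lanes are used.
      int origElements = 3; // i4 lanes requested
      int scale = 2;
      int numElements = static_cast<int>(
          std::ceil(static_cast<double>(origElements) / scale)); // 2 i8 elems
      std::cout << numElements << ' ' << numElements * scale << '\n'; // 2 4
    }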
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
@@ -65,8 +65,8 @@
 // CHECK-NEXT:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
 // CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
 // CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
-// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
+// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
 // CHECK-NEXT:  %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
 // CHECK-NEXT:  %[[C2:.*]] = arith.constant 2 : i8
 // CHECK-NEXT:  %[[C4:.*]] = arith.constant 4 : i8
@@ -90,8 +90,8 @@
 // CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
 // CHECK-NEXT:  %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
-// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
+// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
 // CHECK-NEXT:  %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
 // CHECK-NEXT:  %[[C2:.*]] = arith.constant 2 : i8
 // CHECK-NEXT:  %[[C4:.*]] = arith.constant 4 : i8
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
@@ -28,8 +28,8 @@
 // CHECK-NEXT:  %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
 // CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
 // CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
-// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[1], offset: ?>>
+// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[1], offset: ?>>
 // CHECK-NEXT:  %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
 // CHECK-NEXT:  %[[C2:.*]] = arith.constant 2 : i8
 // CHECK-NEXT:  %[[C4:.*]] = arith.constant 4 : i8
@@ -54,8 +54,8 @@
 // CHECK-NEXT:  %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
 // CHECK-NEXT:  %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
 // CHECK-NEXT:  %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
-// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
-// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:  %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[1], offset: ?>>
+// CHECK-NEXT:  %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[1], offset: ?>>
 // CHECK-NEXT:  %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
 // CHECK-NEXT:  %[[C2:.*]] = arith.constant 2 : i8
 // CHECK-NEXT:  %[[C4:.*]] = arith.constant 4 : i8
"include/mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h", - "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h", ], includes = ["include"], deps = [ @@ -11084,6 +11084,24 @@ ], ) +cc_library( + name = "MemRefUtils", + srcs = glob( + [ + "lib/Dialect/MemRef/Utils/*.cpp", + ], + ), + hdrs = [ + "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h", + ], + includes = ["include"], + deps = [ + ":AffineDialect", + ":ArithUtils", + ":MemRefDialect", + ], +) + gentbl_cc_library( name = "MemRefPassIncGen", strip_include_prefix = "include", @@ -11128,6 +11146,7 @@ ":LoopLikeInterface", ":MemRefDialect", ":MemRefPassIncGen", + ":MemRefUtils", ":NVGPUDialect", ":Pass", ":RuntimeVerifiableOpInterface", @@ -11562,6 +11581,7 @@ ":IR", ":LoopLikeInterface", ":MemRefDialect", + ":MemRefUtils", ":Pass", ":SideEffectInterfaces", ":TensorDialect",