diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -39,7 +39,7 @@ /// baseBuffer, baseOffset, baseSizes, baseStrides = /// extract_strided_metadata(memref) -/// strides#i = baseStrides#i * subSizes#i -/// offset = baseOffset + sum(subOffset#i * strides#i) +/// strides#i = baseStrides#i * subStrides#i +/// offset = baseOffset + sum(subOffset#i * baseStrides#i) /// sizes = subSizes /// \endverbatim /// @@ -83,8 +83,8 @@ auto origStrides = newExtractStridedMetadata.getStrides(); // Hold the affine symbols and values for the computation of the offset. - SmallVector values(3 * sourceRank + 1); - SmallVector symbols(3 * sourceRank + 1); + SmallVector values(2 * sourceRank + 1); + SmallVector symbols(2 * sourceRank + 1); detail::bindSymbolsList(rewriter.getContext(), symbols); AffineExpr expr = symbols.front(); @@ -105,14 +105,11 @@ rewriter, origLoc, s0 * s1, {subStrides[i], origStride})); // Build up the computation of the offset. 
- unsigned baseIdxForDim = 1 + 3 * i; + unsigned baseIdxForDim = 1 + 2 * i; unsigned subOffsetForDim = baseIdxForDim; - unsigned subStrideForDim = baseIdxForDim + 1; - unsigned origStrideForDim = baseIdxForDim + 2; - expr = expr + symbols[subOffsetForDim] * symbols[subStrideForDim] * - symbols[origStrideForDim]; + unsigned origStrideForDim = baseIdxForDim + 1; + expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim]; values[subOffsetForDim] = subOffsets[i]; - values[subStrideForDim] = subStrides[i]; values[origStrideForDim] = origStride; } diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -24,7 +24,7 @@ // Check that we simplify extract_strided_metadata of subview to // base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata // strides = base_stride_i * subview_stride_i -// offset = base_offset + sum(subview_offsets_i * strides_i). +// offset = base_offset + sum(subview_offsets_i * base_strides_i). // // This test also checks that we don't create useless arith operations // when subview_offsets_i is 0. @@ -42,8 +42,8 @@ // // Final offset is: // origOffset + (== 0) -// base_stride0 * subview_stride0 * subview_offset0 + (== 4 * 1 * 0 == 0) -// base_stride1 * subview_stride1 * subview_offset1 (== 1 * 1 * 2) +// base_stride0 * subview_offset0 + (== 4 * 0 == 0) +// base_stride1 * subview_offset1 (== 1 * 2) // == 2 // // Return the new tuple. 
@@ -171,14 +171,14 @@ // // Orig offset: 0 // Sub offsets: [3, 4, 2] -// => Final offset: 3 * 64 + 4 * 4 * %stride + 2 * 1 + 0 == 16 * %stride + 194 +// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210 // // CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)> -// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0] -> (s0 * 16 + 194)> // CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides // CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>, // CHECK-SAME: %[[DYN_STRIDE:.*]]: index) // +// CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -186,9 +186,8 @@ // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] // // CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]] -// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[DYN_STRIDE]]] // -// CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] +// CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]] func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides( %base: memref<8x16x4xf32>, %stride: index) -> (memref<f32>, index, index, index, index, index) { @@ -262,11 +261,11 @@ // // Orig offset: origOff // Sub offsets: [subO0, subO1, subO2] -// => Final offset: s0 * subS0 * subO0 + ... + s2 * subS2 * subO2 + origOff -// ==> 1 affine map with (rank * 3 + 1) symbols +// => Final offset: s0 * subO0 + ... 
+ s2 * subO2 + origOff +// ==> 1 affine map with (rank * 2 + 1) symbols // // CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> -// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0 + (s1 * s2) * s3 + (s4 * s5) * s6 + (s7 * s8) * s9)> +// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)> // CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic // CHECK-SAME: (%[[ARG:.*]]: memref>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index) // @@ -276,7 +275,7 @@ // CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1] // CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2] // -// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[DYN_STRIDE1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[DYN_STRIDE2]], %[[STRIDES]]#2] +// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2] // // CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]] func.func @extract_strided_metadata_of_subview_all_dynamic(