diff --git a/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Affine/ViewLikeInterfaceUtils.h
@@ -0,0 +1,57 @@
+//===- ViewLikeInterfaceUtils.h - View-like operations interface utilities-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
+#define MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "llvm/ADT/SmallBitVector.h"
+
+namespace mlir {
+
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a producer slice **into** a consumer slice.
+///
+/// Each output vector receives one entry per producer dimension. Producer
+/// dimensions that are not set in `droppedProducerDims` are paired, in order,
+/// with the consumer dimensions and combined as:
+/// - Combined offsets = consumer_offsets * producer_strides + producer_offsets
+/// - Combined sizes = consumer_sizes
+/// - Combined strides = consumer_strides * producer_strides
+/// Producer dimensions set in `droppedProducerDims` (rank-reduced away by the
+/// producer) keep the producer's offset, size and stride unchanged.
+/// Returns failure if the slices cannot be combined.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            ArrayRef<OpFoldResult> producerOffsets,
+                            ArrayRef<OpFoldResult> producerSizes,
+                            ArrayRef<OpFoldResult> producerStrides,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            ArrayRef<OpFoldResult> consumerOffsets,
+                            ArrayRef<OpFoldResult> consumerSizes,
+                            ArrayRef<OpFoldResult> consumerStrides,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
+/// Fills the `combinedOffsets`, `combinedSizes` and `combinedStrides` to use
+/// when combining a `producer` slice op **into** a `consumer` slice op.
+/// Convenience overload that reads the mixed offsets, sizes and strides from
+/// `producer` and `consumer` and forwards them to the overload above.
+LogicalResult
+mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc,
+                            OffsetSizeAndStrideOpInterface producer,
+                            OffsetSizeAndStrideOpInterface consumer,
+                            const llvm::SmallBitVector &droppedProducerDims,
+                            SmallVector<OpFoldResult> &combinedOffsets,
+                            SmallVector<OpFoldResult> &combinedSizes,
+                            SmallVector<OpFoldResult> &combinedStrides);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AFFINE_VIEWLIKEINTERFACEUTILS_H
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.h
@@ -0,0 +1,23 @@
+//===- ViewLikeConsumerImpl.h - Impl. of ViewLikeConsumerOpInterface ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_VIEWLIKECONSUMERIMPL_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_VIEWLIKECONSUMERIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Registers the external models of `ViewLikeConsumerOpInterface` for memref
+/// dialect operations (currently `memref.load`).
+void registerViewLikeConsumerOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_VIEWLIKECONSUMERIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -41,6 +41,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -127,6 +128,7 @@
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
+  memref::registerViewLikeConsumerOpInterfaceExternalModels(registry);
 }
 
 /// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h
@@ -56,6 +56,7 @@
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
 class OffsetSizeAndStrideOpInterface;
+class RewriterBase;
 
 namespace detail {
 
diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -465,4 +465,55 @@
   }];
 }
 
+def ViewLikeConsumerOpInterface : OpInterface<"ViewLikeConsumerOpInterface"> {
+  let description = [{
+    An operation with one or more operands whose producer, if it implements
+    `ViewLikeOpInterface`, can be "folded" into the operation. For example, in
+    `memref.load %A[%indices]`, if `%A` is the result of a `memref.subview`,
+    the load can be rewritten to read directly from the source of the subview
+    with appropriately updated indices.
+  }];
+
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the operands of this operation whose `ViewLikeOpInterface`
+        producer may be folded into this operation.
+      }],
+      /*retTy=*/"::mlir::SmallVector<::mlir::OpOperand*>",
+      /*methodName=*/"getViewLikeConsumerOperands",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the indices to use once `operand`, whose value is produced by
+        `producer`, is replaced with the source of `producer`. Returns failure
+        if the producer is not supported. New IR (e.g. index computations) may
+        be created with `builder`.
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::SmallVector<::mlir::OpFoldResult>>",
+      /*methodName=*/"resolveSourceIndices",
+      /*args=*/(ins "::mlir::OpBuilder&":$builder, "::mlir::OpOperand*":$operand,
+                    "::mlir::Operation*":$producer)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Replace this operation with an equivalent one in which `operand` is
+        replaced by `newOperand` and whose indices are `newIndices` (as
+        computed by `resolveSourceIndices`). This operation may be erased
+        through `rewriter`; callers must not use it after success.
+      }],
+      /*retTy=*/"::mlir::LogicalResult",
+      /*methodName=*/"replaceViewLikeConsumerOp",
+      /*args=*/(ins
+        "::mlir::RewriterBase&":$rewriter,
+        "::mlir::OpOperand*":$operand,
+        "::mlir::Value":$newOperand,
+        "::mlir::ArrayRef<::mlir::OpFoldResult>":$newIndices)
+    >
+  ];
+}
+
 #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Utils/CMakeLists.txt
@@ -2,6 +2,7 @@
   LoopFusionUtils.cpp
   LoopUtils.cpp
   Utils.cpp
+  ViewLikeInterfaceUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
@@ -13,4 +14,5 @@
   MLIRArithmeticUtils
   MLIRMemRefDialect
   MLIRTransformUtils
+  MLIRViewLikeInterface
   )
diff --git a/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Utils/ViewLikeInterfaceUtils.cpp
@@ -0,0 +1,92 @@
+//===- ViewLikeInterfaceUtils.cpp - View-like operations interface utils --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/AffineExpr.h"
+
+using namespace mlir;
+
+LogicalResult mlir::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> producerOffsets,
+    ArrayRef<OpFoldResult> producerSizes,
+    ArrayRef<OpFoldResult> producerStrides,
+    const llvm::SmallBitVector &droppedProducerDims,
+    ArrayRef<OpFoldResult> consumerOffsets,
+    ArrayRef<OpFoldResult> consumerSizes,
+    ArrayRef<OpFoldResult> consumerStrides,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  unsigned producerRank = producerOffsets.size();
+  unsigned consumerRank = consumerOffsets.size();
+  // Validate the shapes of the inputs before creating any IR: every producer
+  // dimension needs a size, a stride and a dropped-dim bit, and each producer
+  // dimension that is not dropped is paired with exactly one consumer
+  // dimension. Indexing past the end of the consumer arrays would otherwise
+  // be undefined behavior.
+  if (producerSizes.size() != producerRank ||
+      producerStrides.size() != producerRank ||
+      droppedProducerDims.size() != producerRank ||
+      consumerSizes.size() != consumerRank ||
+      consumerStrides.size() != consumerRank ||
+      producerRank - droppedProducerDims.count() != consumerRank)
+    return failure();
+
+  combinedOffsets.resize(producerRank);
+  combinedSizes.resize(producerRank);
+  combinedStrides.resize(producerRank);
+
+  AffineExpr s0, s1, d0;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+
+  unsigned consumerPos = 0;
+  for (unsigned i : llvm::seq<unsigned>(0, producerRank)) {
+    if (droppedProducerDims.test(i)) {
+      // Dropped dims are invisible to the consumer: keep the producer's
+      // offset, size and stride unchanged.
+      combinedOffsets[i] = producerOffsets[i];
+      combinedSizes[i] = producerSizes[i];
+      combinedStrides[i] = producerStrides[i];
+      continue;
+    }
+    // combined_offset = consumer_offset * producer_stride + producer_offset.
+    combinedOffsets[i] = makeComposedFoldedAffineApply(
+        builder, loc, d0 * s0 + s1,
+        {consumerOffsets[consumerPos], producerStrides[i], producerOffsets[i]});
+    combinedSizes[i] = consumerSizes[consumerPos];
+    // combined_stride = consumer_stride * producer_stride.
+    combinedStrides[i] = makeComposedFoldedAffineApply(
+        builder, loc, d0 * s0,
+        {consumerStrides[consumerPos], producerStrides[i]});
+    ++consumerPos;
+  }
+  return success();
+}
+
+/// Convenience overload: reads the mixed offsets, sizes and strides of both
+/// ops and forwards them to the ArrayRef-based overload.
+LogicalResult mlir::mergeOffsetsSizesAndStrides(
+    OpBuilder &builder, Location loc, OffsetSizeAndStrideOpInterface producer,
+    OffsetSizeAndStrideOpInterface consumer,
+    const llvm::SmallBitVector &droppedProducerDims,
+    SmallVector<OpFoldResult> &combinedOffsets,
+    SmallVector<OpFoldResult> &combinedSizes,
+    SmallVector<OpFoldResult> &combinedStrides) {
+  SmallVector<OpFoldResult> consumerOffsets = consumer.getMixedOffsets();
+  SmallVector<OpFoldResult> consumerSizes = consumer.getMixedSizes();
+  SmallVector<OpFoldResult> consumerStrides = consumer.getMixedStrides();
+  SmallVector<OpFoldResult> producerOffsets = producer.getMixedOffsets();
+  SmallVector<OpFoldResult> producerSizes = producer.getMixedSizes();
+  SmallVector<OpFoldResult> producerStrides = producer.getMixedStrides();
+  return mlir::mergeOffsetsSizesAndStrides(
+      builder, loc, producerOffsets, producerSizes, producerStrides,
+      droppedProducerDims, consumerOffsets, consumerSizes, consumerStrides,
+      combinedOffsets, combinedSizes, combinedStrides);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
   SimplifyExtractStridedMetadata.cpp
+  ViewLikeConsumerImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -332,7 +332,7 @@
     return failure();
 
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+      .Case<AffineLoadOp>([&](auto op) {
         rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.source(),
                                                   sourceIndices);
       })
@@ -519,9 +519,37 @@
   return success();
 }
 
+/// Folds view-like producers of operands into ViewLikeConsumerOpInterface ops.
+struct RewriteViewLikeMemRefConsumers
+    : ::mlir::OpInterfaceRewritePattern<ViewLikeConsumerOpInterface> {
+  using OpInterfaceRewritePattern<
+      ViewLikeConsumerOpInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(ViewLikeConsumerOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    for (OpOperand *operand : op.getViewLikeConsumerOperands()) {
+      if (!operand->get().getType().isa<MemRefType>())
+        continue;
+      auto producer = operand->get().getDefiningOp<ViewLikeOpInterface>();
+      if (!producer)
+        continue;
+      FailureOr<SmallVector<OpFoldResult>> indices =
+          op.resolveSourceIndices(rewriter, operand, producer.getOperation());
+      if (failed(indices))
+        continue;
+      if (failed(op.replaceViewLikeConsumerOp(
+              rewriter, operand, producer.getViewSource(), *indices)))
+        continue;
+      // `op` has been replaced: stop visiting its (now dangling) operands and
+      // let the driver re-apply the pattern to the new op if needed.
+      return success();
+    }
+    return failure();
+  }
+};
+
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
-               LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                StoreOpOfSubViewOpFolder<AffineStoreOp>,
                StoreOpOfSubViewOpFolder<memref::StoreOp>,
@@ -533,8 +561,8 @@
                LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
-               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(
-      patterns.getContext());
+               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
+               RewriteViewLikeMemRefConsumers>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.cpp
@@ -0,0 +1,88 @@
+//===- ViewLikeConsumerImpl.cpp - Impl. of ViewLikeConsumerOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ViewLikeConsumerImpl.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+using namespace mlir;
+using namespace memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+/// Enables `memref.load` to skip over `memref.subview` producers.
+struct LoadOpInterface
+    : public ViewLikeConsumerOpInterface::ExternalModel<LoadOpInterface,
+                                                        memref::LoadOp> {
+  /// Only the memref operand (operand #0) can absorb a view-like producer.
+  SmallVector<OpOperand *> getViewLikeConsumerOperands(Operation *op) const {
+    return {&op->getOpOperand(0)};
+  }
+
+  /// Computes the load indices into the source of a `memref.subview`
+  /// producer by treating the load as a unit-size, unit-stride slice of the
+  /// subview result. Other producers are not supported yet.
+  FailureOr<SmallVector<OpFoldResult>>
+  resolveSourceIndices(Operation *op, OpBuilder &builder,
+                       ::mlir::OpOperand *operand,
+                       ::mlir::Operation *producer) const {
+    auto loadOp = cast<memref::LoadOp>(op);
+
+    // Handle the case when the producer is a `memref.subview`.
+    if (auto subViewOp = dyn_cast<memref::SubViewOp>(producer)) {
+      auto one = builder.getIndexAttr(1);
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(loadOp.getIndices());
+      SmallVector<OpFoldResult> ones(indices.size(), one);
+      SmallVector<OpFoldResult> combinedOffsets, combinedSizes, combinedStrides;
+      if (failed(mergeOffsetsSizesAndStrides(
+              builder, op->getLoc(), subViewOp.getMixedOffsets(),
+              subViewOp.getMixedSizes(), subViewOp.getMixedStrides(),
+              subViewOp.getDroppedDims(), indices, ones, ones, combinedOffsets,
+              combinedSizes, combinedStrides)))
+        return failure();
+      // Only the offsets are meaningful for a load; they are the new indices.
+      return combinedOffsets;
+    }
+
+    // TODO: handle the `memref.collapse_shape` and `memref.expand_shape` cases.
+    return failure();
+  }
+
+  /// Replaces the load with a new `memref.load` of `newOperand` at
+  /// `newIndices`, materializing constant indices as needed.
+  LogicalResult replaceViewLikeConsumerOp(
+      Operation *op, RewriterBase &rewriter, ::mlir::OpOperand *operand,
+      ::mlir::Value newOperand,
+      ::mlir::ArrayRef<::mlir::OpFoldResult> newIndices) const {
+    assert(operand->getOperandNumber() == 0 &&
+           "only memref operand can be replaced");
+    SmallVector<Value> newIndexVals =
+        getValueOrCreateConstantIndexOp(rewriter, op->getLoc(), newIndices);
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, newOperand, newIndexVals);
+    return success();
+  }
+};
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+/// Attaches the external models above. The affine dialect is loaded because
+/// index resolution may create `affine.apply` operations.
+void mlir::memref::registerViewLikeConsumerOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, MemRefDialect *dialect) {
+    memref::LoadOp::attachInterface<LoadOpInterface>(*ctx);
+    ctx->loadDialect<AffineDialect>();
+  });
+}
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -3,18 +3,18 @@
 func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
   %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
-  return %1 : f32 
+  return %1 : f32
 }
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * 3 + s0)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
 // CHECK: func @fold_static_stride_subview_with_load
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[ARG3]])[%[[ARG1]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[ARG4]])[%[[ARG2]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG4]]]
 // CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]]
 
 // -----
@@ -25,7 +25,7 @@
   %1 = memref.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[?, ?], offset: ?>>
   return %1 : f32
 }
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
 // CHECK: func @fold_dynamic_stride_subview_with_load
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -34,8 +34,8 @@
 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG3]])[%[[ARG5]], %[[ARG1]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[ARG4]])[%[[ARG6]], %[[ARG2]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG5]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]]()[%[[ARG6]], %[[ARG2]], %[[ARG4]]]
 // CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]]]
 
 // -----
@@ -178,7 +178,7 @@
   %1 = memref.load %0[%arg13, %arg14, %arg15, %arg16] : memref<4x1x4x1xf32, strided<[?, ?, ?, ?], offset: ?>>
   return %1 : f32
 }
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s1 + s2 * s0)>
 // CHECK: func @fold_rank_reducing_subview_with_load
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?x?x?x?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -197,14 +197,11 @@
 // CHECK-SAME: %[[ARG14:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG15:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME: %[[ARG16:[a-zA-Z0-9_]+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]](%[[ARG13]])[%[[ARG7]], %[[ARG1]]]
-// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG8]], %[[ARG2]]]
-// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]](%[[ARG14]])[%[[ARG9]], %[[ARG3]]]
-// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]](%[[ARG15]])[%[[ARG10]], %[[ARG4]]]
-// CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]](%[[ARG16]])[%[[ARG11]], %[[ARG5]]]
-// CHECK-DAG: %[[I6:.+]] = affine.apply #[[MAP]](%[[C0]])[%[[ARG12]], %[[ARG6]]]
-// CHECK: memref.load %[[ARG0]][%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]]
+// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP]]()[%[[ARG7]], %[[ARG1]], %[[ARG13]]]
+// CHECK-DAG: %[[I3:.+]] = affine.apply #[[MAP]]()[%[[ARG9]], %[[ARG3]], %[[ARG14]]]
+// CHECK-DAG: %[[I4:.+]] = affine.apply #[[MAP]]()[%[[ARG10]], %[[ARG4]], %[[ARG15]]]
+// CHECK-DAG: %[[I5:.+]] = affine.apply #[[MAP]]()[%[[ARG11]], %[[ARG5]], %[[ARG16]]]
+// CHECK: memref.load %[[ARG0]][%[[I1]], %[[ARG2]], %[[I3]], %[[I4]], %[[I5]], %[[ARG6]]]
 
 // -----