diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -839,8 +839,8 @@ let arguments = (ins PDL_Operation:$target, Variadic:$dynamic_sizes, - DefaultValuedAttr:$static_sizes, - DefaultValuedAttr:$interchange); + DefaultValuedOptionalAttr:$static_sizes, + DefaultValuedOptionalAttr:$interchange); let results = (outs PDL_Operation:$tiled_linalg_op, Variadic:$loops); @@ -917,8 +917,8 @@ let arguments = (ins PDL_Operation:$target, Variadic:$num_threads, Variadic:$tile_sizes, - DefaultValuedAttr:$static_num_threads, - DefaultValuedAttr:$static_tile_sizes, + DefaultValuedOptionalAttr:$static_num_threads, + DefaultValuedOptionalAttr:$static_tile_sizes, OptionalAttr:$mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); @@ -1009,8 +1009,8 @@ let arguments = (ins PDL_Operation:$target, Variadic:$dynamic_sizes, - DefaultValuedAttr:$static_sizes, - DefaultValuedAttr:$interchange); + DefaultValuedOptionalAttr:$static_sizes, + DefaultValuedOptionalAttr:$interchange); let results = (outs PDL_Operation:$tiled_linalg_op, Variadic:$loops); diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1260,9 +1260,9 @@ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs AnyMemRef:$result); let assemblyFormat = [{ @@ -1476,7 +1476,7 @@ or copies. 
A reassociation is defined as a grouping of dimensions and is represented - with an array of I64ArrayAttr attributes. + with an array of DenseI64ArrayAttr attributes. Example: @@ -1563,7 +1563,7 @@ type. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. Note: Only the dimensions within a reassociation group must be contiguous. The remaining dimensions may be non-contiguous. @@ -1855,9 +1855,9 @@ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides); + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides); let results = (outs AnyMemRef:$result); let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -326,9 +326,9 @@ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); @@ -807,9 +807,9 @@ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); @@ -1013,7 +1013,7 @@ rank whose sizes are a reassociation of the original `src`. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. 
+ represented with an array of DenseI64ArrayAttr attribute. The verification rule is that the reassociation maps are applied to the result tensor with the higher rank to obtain the operand tensor with the @@ -1065,7 +1065,7 @@ rank whose sizes are a reassociation of the original `src`. A reassociation is defined as a continuous grouping of dimensions and is - represented with an array of I64ArrayAttr attribute. + represented with an array of DenseI64ArrayAttr attribute. The verification rule is that the reassociation maps are applied to the operand tensor with the higher rank to obtain the result tensor with the @@ -1206,8 +1206,8 @@ AnyTensor:$source, Variadic:$low, Variadic:$high, - I64ArrayAttr:$static_low, - I64ArrayAttr:$static_high, + DenseI64ArrayAttr:$static_low, + DenseI64ArrayAttr:$static_high, UnitAttr:$nofold); let regions = (region SizedRegion<1>:$region); @@ -1254,16 +1254,17 @@ // Return a vector of all the static or dynamic values (low/high padding) of // the op. - inline SmallVector getMixedPadImpl(ArrayAttr staticAttrs, + inline SmallVector getMixedPadImpl(ArrayRef staticAttrs, ValueRange values) { + Builder builder(*this); SmallVector res; unsigned numDynamic = 0; unsigned count = staticAttrs.size(); for (unsigned idx = 0; idx < count; ++idx) { - if (ShapedType::isDynamic(staticAttrs[idx].cast().getInt())) + if (ShapedType::isDynamic(staticAttrs[idx])) res.push_back(values[numDynamic++]); else - res.push_back(staticAttrs[idx]); + res.push_back(builder.getI64IntegerAttr(staticAttrs[idx])); } return res; } @@ -1400,9 +1401,9 @@ Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides + DenseI64ArrayAttr:$static_offsets, + DenseI64ArrayAttr:$static_sizes, + DenseI64ArrayAttr:$static_strides ); let assemblyFormat = [{ $source `into` $dest `` @@ -1748,7 +1749,7 @@ DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, 
Variadic:$inner_tiles, - I64ArrayAttr:$static_inner_tiles); + DenseI64ArrayAttr:$static_inner_tiles); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source @@ -1803,7 +1804,7 @@ DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, - I64ArrayAttr:$static_inner_tiles); + DenseI64ArrayAttr:$static_inner_tiles); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -87,6 +87,18 @@ SmallVector getAsValues(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec); +/// Return a vector of OpFoldResults with the same size a staticValues, but all +/// elements for which ShapedType::isDynamic is true, will be replaced by +/// dynamicValues. +SmallVector getMixedValues(ArrayRef staticValues, + ValueRange dynamicValues, Builder &b); + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -21,18 +21,6 @@ namespace mlir { -/// Return a vector of OpFoldResults with the same size a staticValues, but all -/// elements for which ShapedType::isDynamic is true, will be replaced by -/// dynamicValues. -SmallVector getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues); - -/// Decompose a vector of mixed static or dynamic values into the corresponding -/// pair of arrays. 
This is the inverse function of `getMixedValues`. -std::pair> -decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues); - class OffsetSizeAndStrideOpInterface; namespace detail { @@ -61,7 +49,7 @@ /// idiomatic printing of mixed value and integer attributes in a list. E.g. /// `[%arg0, 7, 42, %arg42]`. void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers); + OperandRange values, ArrayRef integers); /// Pasrer hook for custom directive in assemblyFormat. /// @@ -79,13 +67,14 @@ ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers); + DenseI64ArrayAttr &integers); /// Verify that a the `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. LogicalResult verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned expectedNumElements, - ArrayAttr attr, ValueRange values); + ArrayRef attr, + ValueRange values); } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -124,7 +124,7 @@ /*desc=*/[{ Return the static offset attributes. }], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_offsets", /*args=*/(ins), /*methodBody=*/"", @@ -136,7 +136,7 @@ /*desc=*/[{ Return the static size attributes. }], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_sizes", /*args=*/(ins), /*methodBody=*/"", @@ -148,7 +148,7 @@ /*desc=*/[{ Return the dynamic stride attributes. 
}], - /*retTy=*/"::mlir::ArrayAttr", + /*retTy=*/"::llvm::ArrayRef", /*methodName=*/"static_strides", /*args=*/(ins), /*methodBody=*/"", @@ -165,8 +165,9 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ + Builder b($_op->getContext()); return ::mlir::getMixedValues($_op.getStaticOffsets(), - $_op.getOffsets()); + $_op.getOffsets(), b); }] >, InterfaceMethod< @@ -178,7 +179,8 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes()); + Builder b($_op->getContext()); + return ::mlir::getMixedValues($_op.getStaticSizes(), $_op.sizes(), b); }] >, InterfaceMethod< @@ -190,8 +192,9 @@ /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ + Builder b($_op->getContext()); return ::mlir::getMixedValues($_op.getStaticStrides(), - $_op.getStrides()); + $_op.getStrides(), b); }] >, @@ -202,9 +205,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_offsets() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_offsets()[idx]); }] >, InterfaceMethod< @@ -214,9 +215,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_sizes() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_sizes()[idx]); }] >, InterfaceMethod< @@ -226,9 +225,7 @@ /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - ::llvm::APInt v = *(static_strides() - .template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return ::mlir::ShapedType::isDynamic(v.getSExtValue()); + return ::mlir::ShapedType::isDynamic(static_strides()[idx]); }] >, InterfaceMethod< @@ -241,9 +238,7 @@ /*methodBody=*/"", 
/*defaultImplementation=*/[{ assert(!$_op.isDynamicOffset(idx) && "expected static offset"); - ::llvm::APInt v = *(static_offsets(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_offsets()[idx]; }] >, InterfaceMethod< @@ -256,9 +251,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ assert(!$_op.isDynamicSize(idx) && "expected static size"); - ::llvm::APInt v = *(static_sizes(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_sizes()[idx]; }] >, InterfaceMethod< @@ -271,9 +264,7 @@ /*methodBody=*/"", /*defaultImplementation=*/[{ assert(!$_op.isDynamicStride(idx) && "expected static stride"); - ::llvm::APInt v = *(static_strides(). - template getAsValueRange<::mlir::IntegerAttr>().begin() + idx); - return v.getSExtValue(); + return static_strides()[idx]; }] >, @@ -289,7 +280,7 @@ /*defaultImplementation=*/[{ assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_offsets().template cast<::mlir::ArrayAttr>(), + static_offsets(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + numDynamic; @@ -307,7 +298,7 @@ /*defaultImplementation=*/[{ assert($_op.isDynamicSize(idx) && "expected dynamic size"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().template cast<::mlir::ArrayAttr>(), ::mlir::ShapedType::isDynamic, idx); + static_sizes(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + offsets().size() + numDynamic; }] @@ -324,7 +315,7 @@ /*defaultImplementation=*/[{ assert($_op.isDynamicStride(idx) && "expected dynamic stride"); auto numDynamic = getNumDynamicEntriesUpToIdx( - static_strides().template cast<::mlir::ArrayAttr>(), + static_strides(), ::mlir::ShapedType::isDynamic, idx); return $_op.getOffsetSizeAndStrideStartOperandIndex() + @@ -333,20 +324,20 @@ >, InterfaceMethod< 
/*desc=*/[{ - Helper method to compute the number of dynamic entries of `attr`, up to + Helper method to compute the number of dynamic entries of `staticVals`, up to `idx` using `isDynamic` to determine whether an entry is dynamic. }], /*retTy=*/"unsigned", /*methodName=*/"getNumDynamicEntriesUpToIdx", - /*args=*/(ins "::mlir::ArrayAttr":$attr, + /*args=*/(ins "::llvm::ArrayRef":$staticVals, "::llvm::function_ref":$isDynamic, "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](::mlir::Attribute attr) { - return isDynamic(attr.cast<::mlir::IntegerAttr>().getInt()); + staticVals.begin(), staticVals.begin() + idx, + [&](int64_t val) { + return isDynamic(val); }); }] >, diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1705,10 +1705,8 @@ auto viewMemRefType = subViewOp.getType(); auto inferredType = memref::SubViewOp::inferResultType( - subViewOp.getSourceType(), - extractFromI64ArrayAttr(subViewOp.getStaticOffsets()), - extractFromI64ArrayAttr(subViewOp.getStaticSizes()), - extractFromI64ArrayAttr(subViewOp.getStaticStrides())) + subViewOp.getSourceType(), subViewOp.getStaticOffsets(), + subViewOp.getStaticSizes(), subViewOp.getStaticStrides()) .cast(); auto targetElementTy = typeConverter->convertType(viewMemRefType.getElementType()); diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -30,8 +30,8 @@ PatternRewriter &rewriter) const final { Location loc = sliceOp.getLoc(); Value input = sliceOp.getInput(); - SmallVector strides, sizes; - auto starts = sliceOp.getStart(); + SmallVector strides, sizes, starts; + starts 
= extractFromI64ArrayAttr(sliceOp.getStart()); strides.resize(sliceOp.getType().template cast().getRank(), 1); SmallVector dynSizes; @@ -44,15 +44,15 @@ auto dim = rewriter.create(loc, input, index); auto offset = rewriter.create( - loc, - rewriter.getIndexAttr(starts[index].cast().getInt())); + loc, rewriter.getIndexAttr(starts[index])); dynSizes.push_back(rewriter.create(loc, dim, offset)); } auto newSliceOp = rewriter.create( sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes, - ValueRange({}), starts, rewriter.getI64ArrayAttr(sizes), - rewriter.getI64ArrayAttr(strides)); + ValueRange({}), rewriter.getDenseI64ArrayAttr(starts), + rewriter.getDenseI64ArrayAttr(sizes), + rewriter.getDenseI64ArrayAttr(strides)); rewriter.replaceOp(sliceOp, newSliceOp.getResult()); return success(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -40,16 +40,6 @@ return result; } -/// Extracts a vector of int64_t from an array attribute. Asserts if the -/// attribute contains values other than integers. -static SmallVector extractI64Array(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getSExtValue()); - return result; -} - namespace { /// A simple pattern rewriter that implements no special logic. 
class SimpleRewriter : public PatternRewriter { @@ -1205,7 +1195,7 @@ DiagnosedSilenceableFailure transform::TileOp::apply(TransformResults &transformResults, TransformState &state) { - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; @@ -1270,7 +1260,7 @@ }); } - tilingOptions.setInterchange(extractI64Array(getInterchange())); + tilingOptions.setInterchange(getInterchange()); SimpleRewriter rewriter(linalgOp.getContext()); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(linalgOp.getOperation()), @@ -1298,7 +1288,7 @@ SmallVector transform::TileOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); SmallVector results; results.reserve(tileSizes.size()); unsigned dynamicPos = 0; @@ -1313,22 +1303,51 @@ return results; } +// We want to parse `DenseI64ArrayAttr` using the short form without the +// `array` prefix to be consistent in the IR with `parseDynamicIndexList`. 
+static ParseResult parseOptionalInterchange(OpAsmParser &parser, + OperationState &result) { + // No `{`: leave the attribute unset so its declared default applies. + if (failed(parser.parseOptionalLBrace())) + return success(); + // The parse* helpers already report errors at the offending token. + if (parser.parseKeyword("interchange") || parser.parseEqual()) + return failure(); + Attribute interchange = DenseI64ArrayAttr::parse(parser, Type{}); + if (!interchange) + return failure(); // Never record a null attribute. + result.addAttribute("interchange", interchange); + return parser.parseRBrace(); +} + +static void printOptionalInterchange(OpAsmPrinter &p, + ArrayRef interchangeVals) { + // Inverse of parseOptionalInterchange; an empty list prints nothing. + if (interchangeVals.empty()) + return; + p << " {interchange = ["; + llvm::interleaveComma(interchangeVals, p); + p << "]}"; +} + ParseResult transform::TileOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; - ArrayAttr staticSizes; + DenseI64ArrayAttr staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || - parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || - parser.parseOptionalAttrDict(result.attributes)) + parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) return ParseResult::failure(); + // Parse optional interchange.
+ if (failed(parseOptionalInterchange(parser, result))) + return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = - staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); + staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } @@ -1336,7 +1355,7 @@ void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); - p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); + printOptionalInterchange(p, getInterchange()); } void transform::TileOp::getEffects( @@ -1379,13 +1398,13 @@ // bugs ensue. MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); - auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); + auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/ValueRange{}, /*tile_sizes=*/dynamicTileSizes, - /*static_num_threads=*/builder.getI64ArrayAttr({}), + /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), /*static_tile_sizes=*/staticTileSizesAttr, /*mapping=*/mapping); } @@ -1414,14 +1433,14 @@ // bugs ensue. 
MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); - auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); + auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); build(builder, result, /*resultTypes=*/TypeRange{operationType, operationType}, /*target=*/target, /*num_threads=*/dynamicNumThreads, /*tile_sizes=*/ValueRange{}, /*static_num_threads=*/staticNumThreadsAttr, - /*static_tile_sizes=*/builder.getI64ArrayAttr({}), + /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), /*mapping=*/mapping); } @@ -1547,11 +1566,13 @@ } SmallVector TileToForeachThreadOp::getMixedNumThreads() { - return getMixedValues(getStaticNumThreads(), getNumThreads()); + Builder b(getContext()); + return getMixedValues(getStaticNumThreads(), getNumThreads(), b); } SmallVector TileToForeachThreadOp::getMixedTileSizes() { - return getMixedValues(getStaticTileSizes(), getTileSizes()); + Builder b(getContext()); + return getMixedValues(getStaticTileSizes(), getTileSizes(), b); } LogicalResult TileToForeachThreadOp::verify() { @@ -1567,7 +1588,7 @@ DiagnosedSilenceableFailure transform::TileToScfForOp::apply(TransformResults &transformResults, TransformState &state) { - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); ArrayRef targets = state.getPayloadOps(getTarget()); SmallVector> dynamicSizeProducers; @@ -1632,7 +1653,7 @@ }); } - tilingOptions.setInterchange(extractI64Array(getInterchange())); + tilingOptions.setInterchange(getInterchange()); SimpleRewriter rewriter(tilingInterfaceOp.getContext()); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); @@ -1655,7 +1676,7 @@ SmallVector transform::TileToScfForOp::getMixedSizes() { ValueRange dynamic = getDynamicSizes(); - SmallVector tileSizes = extractFromI64ArrayAttr(getStaticSizes()); + ArrayRef tileSizes = getStaticSizes(); SmallVector results; 
results.reserve(tileSizes.size()); unsigned dynamicPos = 0; @@ -1674,18 +1695,20 @@ OperationState &result) { OpAsmParser::UnresolvedOperand target; SmallVector dynamicSizes; - ArrayAttr staticSizes; + DenseI64ArrayAttr staticSizes; auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || parseDynamicIndexList(parser, dynamicSizes, staticSizes) || - parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || - parser.parseOptionalAttrDict(result.attributes)) + parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) return ParseResult::failure(); + // Parse optional interchange. + if (failed(parseOptionalInterchange(parser, result))) + return ParseResult::failure(); result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); size_t numExpectedLoops = - staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0); + staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); result.addTypes(SmallVector(numExpectedLoops + 1, pdlOperationType)); return success(); } @@ -1693,7 +1716,7 @@ void TileToScfForOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); - p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); + printOptionalInterchange(p, getInterchange()); } void transform::TileToScfForOp::getEffects( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -348,7 +348,7 @@ SmallVector outputExprs; for (unsigned i = 0; i < resultShapedType.getRank(); ++i) { outputExprs.push_back(getAffineDimExpr(i, rewriter.getContext()) + - padOp.getStaticLow()[i].cast().getInt()); + padOp.getStaticLow()[i]); } SmallVector transferMaps = { 
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1776,8 +1776,9 @@ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -1823,8 +1824,8 @@ << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. - for (auto &en : llvm::enumerate(llvm::zip( - resultType.getShape(), extractFromI64ArrayAttr(getStaticSizes())))) { + for (auto &en : + llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultSize) && @@ -1844,7 +1845,7 @@ << resultType; // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = extractFromI64ArrayAttr(getStaticOffsets()).front(); + int64_t expectedOffset = getStaticOffsets().front(); if (!ShapedType::isDynamic(resultOffset) && !ShapedType::isDynamic(expectedOffset) && resultOffset != expectedOffset) @@ -1852,8 +1853,8 @@ << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. 
- for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(getStaticStrides())))) { + for (auto &en : + llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultStride) && @@ -2665,8 +2666,9 @@ .cast(); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -2831,9 +2833,7 @@ // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( - baseType, extractFromI64ArrayAttr(getStaticOffsets()), - extractFromI64ArrayAttr(getStaticSizes()), - extractFromI64ArrayAttr(getStaticStrides())); + baseType, getStaticOffsets(), getStaticSizes(), getStaticStrides()); auto result = isRankReducedMemRefType(expectedType.cast(), subViewType, getMixedSizes()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -45,9 +45,8 @@ builder.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), val.getType().cast(), - extractFromI64ArrayAttr(subviewUse.getStaticOffsets()), - extractFromI64ArrayAttr(subviewUse.getStaticSizes()), - extractFromI64ArrayAttr(subviewUse.getStaticStrides())); + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); Value newSubview = builder.create( subviewUse->getLoc(), newType.cast(), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), diff --git 
a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -337,8 +337,7 @@ SmallVector sizes = extractOperand.getMixedSizes(); auto dimMask = computeRankReductionMask( - extractFromI64ArrayAttr(extractOperand.getStaticSizes()), - extractOperand.getType().getShape()); + extractOperand.getStaticSizes(), extractOperand.getType().getShape()); size_t dimIndex = 0; for (size_t i = 0, e = sizes.size(); i < e; i++) { if (dimMask && dimMask->count(i)) @@ -1713,8 +1712,9 @@ .cast(); } build(b, result, resultType, source, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -1949,13 +1949,13 @@ return failure(); // Check if there are any dynamic parts, which are not supported. 
- auto offsets = extractFromI64ArrayAttr(op.getStaticOffsets()); + auto offsets = op.getStaticOffsets(); if (llvm::is_contained(offsets, ShapedType::kDynamic)) return failure(); - auto sizes = extractFromI64ArrayAttr(op.getStaticSizes()); + auto sizes = op.getStaticSizes(); if (llvm::is_contained(sizes, ShapedType::kDynamic)) return failure(); - auto strides = extractFromI64ArrayAttr(op.getStaticStrides()); + auto strides = op.getStaticStrides(); if (llvm::is_contained(strides, ShapedType::kDynamic)) return failure(); @@ -2124,8 +2124,9 @@ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -2153,17 +2154,14 @@ /// Rank-reducing type verification for both InsertSliceOp and /// ParallelInsertSliceOp. -static SliceVerificationResult -verifyInsertSliceOp(ShapedType srcType, ShapedType dstType, - ArrayAttr staticOffsets, ArrayAttr staticSizes, - ArrayAttr staticStrides, - ShapedType *expectedType = nullptr) { +static SliceVerificationResult verifyInsertSliceOp( + ShapedType srcType, ShapedType dstType, ArrayRef staticOffsets, + ArrayRef staticSizes, ArrayRef staticStrides, + ShapedType *expectedType = nullptr) { // insert_slice is the inverse of extract_slice, use the same type // inference. 
RankedTensorType expected = ExtractSliceOp::inferResultType( - dstType, extractFromI64ArrayAttr(staticOffsets), - extractFromI64ArrayAttr(staticSizes), - extractFromI64ArrayAttr(staticStrides)); + dstType, staticOffsets, staticSizes, staticStrides); if (expectedType) *expectedType = expected; return isRankReducedType(expected, srcType); @@ -2482,9 +2480,8 @@ LogicalResult PadOp::verify() { auto sourceType = getSource().getType().cast(); auto resultType = getResult().getType().cast(); - auto expectedType = PadOp::inferResultType( - sourceType, extractFromI64ArrayAttr(getStaticLow()), - extractFromI64ArrayAttr(getStaticHigh())); + auto expectedType = + PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh()); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; @@ -2556,8 +2553,9 @@ ArrayRef attrs) { auto sourceType = source.getType().cast(); auto resultType = inferResultType(sourceType, staticLow, staticHigh); - build(b, result, resultType, source, low, high, b.getI64ArrayAttr(staticLow), - b.getI64ArrayAttr(staticHigh), nofold ? b.getUnitAttr() : UnitAttr()); + build(b, result, resultType, source, low, high, + b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), + nofold ? b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } @@ -2591,7 +2589,7 @@ } assert(resultType.isa()); build(b, result, resultType, source, dynamicLow, dynamicHigh, - b.getI64ArrayAttr(staticLow), b.getI64ArrayAttr(staticHigh), + b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh), nofold ? 
b.getUnitAttr() : UnitAttr()); result.addAttributes(attrs); } @@ -2658,8 +2656,7 @@ auto newResultType = PadOp::inferResultType( castOp.getSource().getType().cast(), - extractFromI64ArrayAttr(padTensorOp.getStaticLow()), - extractFromI64ArrayAttr(padTensorOp.getStaticHigh()), + padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(), padTensorOp.getResultType().getShape()); if (newResultType == padTensorOp.getResultType()) { @@ -2940,8 +2937,9 @@ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides, ShapedType::kDynamic); build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes, - dynamicStrides, b.getI64ArrayAttr(staticOffsets), - b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); + dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets), + b.getDenseI64ArrayAttr(staticSizes), + b.getDenseI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } @@ -3086,12 +3084,12 @@ static SmallVector getMixedTilesImpl(OpTy op) { static_assert(llvm::is_one_of::value, "applies to only pack or unpack operations"); + Builder builder(op); SmallVector mixedInnerTiles; unsigned dynamicValIndex = 0; - for (Attribute attr : op.getStaticInnerTiles()) { - auto tileAttr = attr.cast(); - if (!ShapedType::isDynamic(tileAttr.getInt())) - mixedInnerTiles.push_back(tileAttr); + for (int64_t staticTile : op.getStaticInnerTiles()) { + if (!ShapedType::isDynamic(staticTile)) + mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile)); else mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]); } diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -137,4 +137,41 @@ return getValueOrCreateConstantIndexOp(b, loc, value); })); } + +/// Return a vector of OpFoldResults with the same size as staticValues, but all +/// elements for which ShapedType::isDynamic is true will be replaced by +///
dynamicValues. +SmallVector getMixedValues(ArrayRef staticValues, + ValueRange dynamicValues, Builder &b) { + SmallVector res; + res.reserve(staticValues.size()); + unsigned numDynamic = 0; + unsigned count = static_cast(staticValues.size()); + for (unsigned idx = 0; idx < count; ++idx) { + int64_t value = staticValues[idx]; + res.push_back(ShapedType::isDynamic(value) + ? OpFoldResult{dynamicValues[numDynamic++]} + : OpFoldResult{b.getI64IntegerAttr(staticValues[idx])}); + } + return res; +} + +/// Decompose a vector of mixed static or dynamic values into the corresponding +/// pair of arrays. This is the inverse function of `getMixedValues`. +std::pair> +decomposeMixedValues(Builder &b, + const SmallVectorImpl &mixedValues) { + SmallVector staticValues; + SmallVector dynamicValues; + for (const auto &it : mixedValues) { + if (it.is()) { + staticValues.push_back(it.get().cast().getInt()); + } else { + staticValues.push_back(ShapedType::kDynamic); + dynamicValues.push_back(it.get()); + } + } + return {b.getI64ArrayAttr(staticValues), dynamicValues}; +} + } // namespace mlir diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -20,15 +20,15 @@ LogicalResult mlir::verifyListOfOperandsOrIntegers(Operation *op, StringRef name, unsigned numElements, - ArrayAttr attr, + ArrayRef staticVals, ValueRange values) { - /// Check static and dynamic offsets/sizes/strides does not overflow type. - if (attr.size() != numElements) + // Check that the static and dynamic offsets/sizes/strides have the expected number of entries.
+ if (staticVals.size() != numElements) return op->emitError("expected ") << numElements << " " << name << " values"; unsigned expectedNumDynamicEntries = - llvm::count_if(attr.getValue(), [&](Attribute attr) { - return ShapedType::isDynamic(attr.cast().getInt()); + llvm::count_if(staticVals, [&](int64_t staticVal) { + return ShapedType::isDynamic(staticVal); }); if (values.size() != expectedNumDynamicEntries) return op->emitError("expected ") @@ -70,19 +70,19 @@ } void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers) { + OperandRange values, + ArrayRef integers) { printer << '['; if (integers.empty()) { printer << "]"; return; } unsigned idx = 0; - llvm::interleaveComma(integers, printer, [&](Attribute a) { - int64_t val = a.cast().getInt(); - if (ShapedType::isDynamic(val)) + llvm::interleaveComma(integers, printer, [&](int64_t integer) { + if (ShapedType::isDynamic(integer)) printer << values[idx++]; else - printer << val; + printer << integer; }); printer << ']'; } @@ -90,28 +90,28 @@ ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers) { + DenseI64ArrayAttr &integers) { if (failed(parser.parseLSquare())) return failure(); // 0-D. 
if (succeeded(parser.parseOptionalRSquare())) { - integers = parser.getBuilder().getArrayAttr({}); + integers = parser.getBuilder().getDenseI64ArrayAttr({}); return success(); } - SmallVector attrVals; + SmallVector integerVals; while (true) { OpAsmParser::UnresolvedOperand operand; auto res = parser.parseOptionalOperand(operand); if (res.has_value() && succeeded(res.value())) { values.push_back(operand); - attrVals.push_back(ShapedType::kDynamic); + integerVals.push_back(ShapedType::kDynamic); } else { - IntegerAttr attr; - if (failed(parser.parseAttribute(attr))) + int64_t integer; + if (failed(parser.parseInteger(integer))) return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; - attrVals.push_back(attr.getInt()); + integerVals.push_back(integer); } if (succeeded(parser.parseOptionalComma())) @@ -120,7 +120,7 @@ return failure(); break; } - integers = parser.getBuilder().getI64ArrayAttr(attrVals); + integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals); return success(); } @@ -144,34 +144,3 @@ return false; return true; } - -SmallVector mlir::getMixedValues(ArrayAttr staticValues, - ValueRange dynamicValues) { - SmallVector res; - res.reserve(staticValues.size()); - unsigned numDynamic = 0; - unsigned count = static_cast(staticValues.size()); - for (unsigned idx = 0; idx < count; ++idx) { - APInt value = staticValues[idx].cast().getValue(); - res.push_back(ShapedType::isDynamic(value.getSExtValue()) - ? 
OpFoldResult{dynamicValues[numDynamic++]} - : OpFoldResult{staticValues[idx]}); - } - return res; -} - -std::pair> -mlir::decomposeMixedValues(Builder &b, - const SmallVectorImpl &mixedValues) { - SmallVector staticValues; - SmallVector dynamicValues; - for (const auto &it : mixedValues) { - if (it.is()) { - staticValues.push_back(it.get().cast().getInt()); - } else { - staticValues.push_back(ShapedType::kDynamic); - dynamicValues.push_back(it.get()); - } - } - return {b.getI64ArrayAttr(staticValues), dynamicValues}; -} diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -49,6 +49,15 @@ return ArrayAttr.get([_get_int64_attr(v) for v in values]) +def _get_dense_int64_array_attr( + values: Sequence[int]) -> DenseI64ArrayAttr: + """Creates a dense integer array from a sequence of integers. + Expects the thread-local MLIR context to have been set by the context + manager. 
+ """ + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -250,14 +259,11 @@ else: for size in sizes: if isinstance(size, int): - static_sizes.append(IntegerAttr.get(i64_type, size)) - elif isinstance(size, IntegerAttr): static_sizes.append(size) else: - static_sizes.append( - IntegerAttr.get(i64_type, ShapedType.get_dynamic_size())) + static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = ArrayAttr.get(static_sizes) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) @@ -266,14 +272,14 @@ _get_op_result_or_value(target), dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_int_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) if interchange else None, loc=loc, ip=ip) - def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]: + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: if not attr: return [] - return [IntegerAttr(element).value for element in attr] + return [element for element in attr] class VectorizeOp: diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -138,7 +138,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]} + transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]} } // CHECK-LABEL: func @permute_generic @@ -191,8 +191,8 @@ transform.sequence failures(propagate) { ^bb0(%arg1: !pdl.operation): %0 = 
transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange=[1, 2, 0]} - %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange=[1, 0, 2]} + %1, %loops:3 = transform.structured.tile %0 [2000, 3000, 4000] {interchange = [1, 2, 0]} + %2, %loops_2:3 = transform.structured.tile %1 [200, 300, 400] {interchange = [1, 0, 2]} %3, %loops_3:3 = transform.structured.tile %2 [20, 30, 40] } diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -108,7 +108,6 @@ # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1 # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3 - @run def testTileCompact(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) @@ -120,14 +119,11 @@ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] # CHECK: interchange = [0, 1] - @run def testTileAttributes(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) - attr = ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [4, 8]]) - ichange = ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), x) for x in [0, 1]]) + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) with InsertionPoint(sequence.body): structured.TileOp(sequence.bodyTarget, sizes=attr, interchange=ichange) transform.YieldOp() @@ -136,7 +132,6 @@ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8] # CHECK: interchange = [0, 1] - @run def testTileZero(): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get()) @@ -149,7 +144,6 @@ # CHECK: %{{.+}}, %{{.+}}:2 = 
transform.structured.tile %{{.*}}[4, 0, 2, 0] # CHECK: interchange = [0, 1, 2, 3] - @run def testTileDynamic(): with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get())