diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -196,14 +196,14 @@ into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" combinations (4 bytes packed into dwords in the columns of both the source operand tiles; the zero or sign extension is specified with - the attributes). The operation is eventually lowered into one of - the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with - the corresponding tile configuration. + the attributes and defaults to sign extension). The operation is eventually + lowered into one of the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" + instructions with the corresponding tile configuration. Example: ```mlir - %0 = amx.tile_muli %a, %b, %c [true, true] + %0 = amx.tile_muli %a zext, %b zext, %c : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> ``` }]; @@ -211,7 +211,9 @@ let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs, VectorOfRankAndType<[2], [I32, I8]>:$rhs, VectorOfRankAndType<[2], [I32, I8]>:$acc, - BoolArrayAttr:$zext); + UnitAttr:$isZextLhs, + UnitAttr:$isZextRhs + ); let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res); let extraClassDeclaration = [{ VectorType getLhsVectorType() { @@ -224,7 +226,7 @@ return res().getType().cast(); } }]; - let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` " + let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? 
`,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; } diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -85,8 +85,6 @@ } static LogicalResult verify(amx::TileMulIOp op) { - if (op.zext().size() != 2) - return op.emitOpError("unexpected zext length"); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -191,8 +191,8 @@ getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); // Replace operation with intrinsic. Type resType = typeConverter->convertType(cType); - bool zexta = op.zext()[0].cast().getValue(); - bool zextb = op.zext()[1].cast().getValue(); + bool zexta = op.isZextLhs(); + bool zextb = op.isZextRhs(); if (zexta && zextb) rewriter.replaceOpWithNewOp( op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir --- a/mlir/test/Dialect/AMX/invalid.mlir +++ b/mlir/test/Dialect/AMX/invalid.mlir @@ -46,13 +46,3 @@ // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}} %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32> } - -// ----- - -func @zextsize() { - %0 = amx.tile_zero : vector<8x8xi8> - %1 = amx.tile_zero : vector<8x8xi8> - %2 = amx.tile_zero : vector<8x8xi32> - // expected-error@+1 {{'amx.tile_muli' op unexpected zext length}} - %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32> -} diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir --- 
a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -17,13 +17,13 @@ %1 = amx.tile_zero : vector<16x64xi8> %2 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> %3 = amx.tile_load %arg1[%0, %0] : memref into vector<16x16xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %4 : memref, vector<16x16xi32> - %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %5 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %5 : memref, vector<16x16xi32> - %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %6 : memref, vector<16x16xi32> - %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %7 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg1[%0, %0], %7 : memref, vector<16x16xi32> return } diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir --- a/mlir/test/Dialect/AMX/roundtrip.mlir +++ b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -28,14 +28,22 @@ // CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> // CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> // CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xi32> -// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK: %[[m:.*]] = amx.tile_muli %[[x]] zext, %[[y]] zext, %[[z]] : 
vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> // CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xi32> +// Verify the parsing/printing of the zero-extension annotation. +// CHECK: amx.tile_muli %{{.*}}, %{{.*}} zext, %{{.*}} +// CHECK: amx.tile_muli %{{.*}} zext, %{{.*}}, %{{.*}} +// CHECK: amx.tile_muli %{{.*}}, %{{.*}}, %{{.*}} func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) { %0 = constant 0 : index %1 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> %2 = amx.tile_load %arg1[%0, %0] : memref into vector<16x64xi8> %3 = amx.tile_load %arg2[%0, %0] : memref into vector<16x16xi32> - %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %4 = amx.tile_muli %1 zext, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> amx.tile_store %arg2[%0, %0], %4 : memref, vector<16x16xi32> + // Verify the various `zext` combinations. + %5 = amx.tile_muli %1, %2 zext, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %6 = amx.tile_muli %1 zext, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + %7 = amx.tile_muli %1, %2, %3 : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> return }